term_guard/constraints/
foreign_key.rs

//! Foreign key constraint validation for Term.
//!
//! This module provides foreign key validation for checking referential integrity
//! between tables. A foreign key check verifies that every value in a child table's column
//! also exists in the referenced column of a parent table, which prevents orphaned records
//! and keeps data consistent.
//!
//! # Examples
//!
//! ## Basic Foreign Key Validation
//!
//! ```rust
//! use term_guard::constraints::ForeignKeyConstraint;
//! use term_guard::core::{Check, Level};
//!
//! // Validate that all orders.customer_id values exist in customers.id
//! let constraint = ForeignKeyConstraint::new("orders.customer_id", "customers.id");
//!
//! let check = Check::builder("referential_integrity")
//!     .level(Level::Error)
//!     .with_constraint(constraint)
//!     .build();
//! ```
//!
//! ## Foreign Key with Null Handling
//!
//! ```rust
//! use term_guard::constraints::ForeignKeyConstraint;
//!
//! // Allow null values in the foreign key column
//! let constraint = ForeignKeyConstraint::new("orders.customer_id", "customers.id")
//!     .allow_nulls(true);
//! ```
33
34use crate::core::{Constraint, ConstraintResult, ConstraintStatus};
35use crate::error::{Result, TermError};
36use crate::security::SqlSecurity;
37use arrow::array::{Array, Int64Array, StringArray};
38use async_trait::async_trait;
39use datafusion::prelude::*;
40use serde::{Deserialize, Serialize};
41use tracing::{debug, instrument, warn};
42
/// Foreign key constraint for validating referential integrity between tables.
///
/// This constraint checks that the values in a child table's foreign key column
/// exist as values in the parent table's referenced column. It's essential for maintaining
/// data consistency and preventing orphaned records.
///
/// Configuration carried by the constraint:
/// - Null value handling (allow/disallow nulls in the foreign key column)
/// - A join-strategy flag (`use_left_join`)
/// - A cap on how many violating values are collected as examples
///
/// Both column names must be qualified as `table.column`; they are split and
/// validated as SQL identifiers when the constraint is evaluated, not at construction.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ForeignKeyConstraint {
    /// Qualified column in the child table (e.g., "orders.customer_id")
    child_column: String,
    /// Qualified column in the parent table (e.g., "customers.id")
    parent_column: String,
    /// Whether NULL values in the foreign key column are accepted (`false` by default)
    allow_nulls: bool,
    /// Join-strategy flag; `true` (the default) selects the LEFT JOIN strategy
    use_left_join: bool,
    /// Maximum number of violation examples to collect (100 by default, 0 disables collection)
    max_violations_reported: usize,
}
67
68impl ForeignKeyConstraint {
69    /// Create a new foreign key constraint.
70    ///
71    /// # Arguments
72    ///
73    /// * `child_column` - Column in child table containing foreign key values
74    /// * `parent_column` - Column in parent table containing referenced values
75    ///
76    /// # Examples
77    ///
78    /// ```rust
79    /// use term_guard::constraints::ForeignKeyConstraint;
80    ///
81    /// let fk = ForeignKeyConstraint::new("orders.customer_id", "customers.id");
82    /// ```
83    pub fn new(child_column: impl Into<String>, parent_column: impl Into<String>) -> Self {
84        Self {
85            child_column: child_column.into(),
86            parent_column: parent_column.into(),
87            allow_nulls: false,
88            use_left_join: true,
89            max_violations_reported: 100,
90        }
91    }
92
93    /// Set whether to allow NULL values in the foreign key column.
94    ///
95    /// When `true`, NULL values in the child column are considered valid.
96    /// When `false`, NULL values are treated as constraint violations.
97    pub fn allow_nulls(mut self, allow: bool) -> Self {
98        self.allow_nulls = allow;
99        self
100    }
101
102    /// Set the join strategy for validation.
103    ///
104    /// - `true` (default): Use LEFT JOIN strategy, better for tables with few violations
105    /// - `false`: Use NOT EXISTS strategy, better for tables with many violations
106    pub fn use_left_join(mut self, use_left_join: bool) -> Self {
107        self.use_left_join = use_left_join;
108        self
109    }
110
111    /// Set the maximum number of violation examples to report.
112    ///
113    /// Defaults to 100. Set to 0 to disable violation example collection.
114    pub fn max_violations_reported(mut self, max_violations: usize) -> Self {
115        self.max_violations_reported = max_violations;
116        self
117    }
118
119    /// Get the child column name
120    pub fn child_column(&self) -> &str {
121        &self.child_column
122    }
123
124    /// Get the parent column name  
125    pub fn parent_column(&self) -> &str {
126        &self.parent_column
127    }
128
129    /// Parse table and column from qualified column name (e.g., "orders.customer_id")
130    fn parse_qualified_column(&self, qualified_column: &str) -> Result<(String, String)> {
131        let parts: Vec<&str> = qualified_column.split('.').collect();
132        if parts.len() != 2 {
133            return Err(TermError::constraint_evaluation(
134                "foreign_key",
135                format!(
136                    "Foreign key column must be qualified (table.column): '{qualified_column}'"
137                ),
138            ));
139        }
140
141        let table = parts[0].to_string();
142        let column = parts[1].to_string();
143
144        // Validate SQL identifiers for security
145        SqlSecurity::validate_identifier(&table)?;
146        SqlSecurity::validate_identifier(&column)?;
147
148        Ok((table, column))
149    }
150
151    /// Generate SQL query for foreign key validation using LEFT JOIN strategy
152    fn generate_left_join_query(
153        &self,
154        child_table: &str,
155        child_col: &str,
156        parent_table: &str,
157        parent_col: &str,
158    ) -> Result<String> {
159        let null_condition = if self.allow_nulls {
160            format!("AND {child_table}.{child_col} IS NOT NULL")
161        } else {
162            String::new()
163        };
164
165        let sql = format!(
166            "SELECT 
167                COUNT(*) as total_violations,
168                COUNT(DISTINCT {child_table}.{child_col}) as unique_violations
169             FROM {child_table} 
170             LEFT JOIN {parent_table} ON {child_table}.{child_col} = {parent_table}.{parent_col}
171             WHERE {parent_table}.{parent_col} IS NULL {null_condition}"
172        );
173
174        debug!("Generated foreign key validation query: {}", sql);
175        Ok(sql)
176    }
177
178    /// Generate SQL query to get violation examples
179    fn generate_violations_query(
180        &self,
181        child_table: &str,
182        child_col: &str,
183        parent_table: &str,
184        parent_col: &str,
185    ) -> Result<String> {
186        if self.max_violations_reported == 0 {
187            return Ok(String::new());
188        }
189
190        let null_condition = if self.allow_nulls {
191            format!("AND {child_table}.{child_col} IS NOT NULL")
192        } else {
193            String::new()
194        };
195
196        let limit = self.max_violations_reported;
197        let sql = format!(
198            "SELECT DISTINCT {child_table}.{child_col} as violating_value
199             FROM {child_table}
200             LEFT JOIN {parent_table} ON {child_table}.{child_col} = {parent_table}.{parent_col}
201             WHERE {parent_table}.{parent_col} IS NULL {null_condition}
202             LIMIT {limit}"
203        );
204
205        debug!("Generated violations query: {}", sql);
206        Ok(sql)
207    }
208
209    /// Collect violation examples using memory-efficient approach.
210    ///
211    /// This method limits memory usage by:
212    /// 1. Using LIMIT in the SQL query to restrict result size at the database level
213    /// 2. Pre-allocating vector with known maximum size
214    /// 3. Processing results in a single pass without intermediate collections
215    /// 4. Early termination when max violations are reached
216    async fn collect_violation_examples_efficiently(
217        &self,
218        ctx: &SessionContext,
219        child_table: &str,
220        child_col: &str,
221        parent_table: &str,
222        parent_col: &str,
223    ) -> Result<Vec<String>> {
224        if self.max_violations_reported == 0 {
225            return Ok(Vec::new());
226        }
227
228        let violations_sql =
229            self.generate_violations_query(child_table, child_col, parent_table, parent_col)?;
230        if violations_sql.is_empty() {
231            return Ok(Vec::new());
232        }
233
234        debug!("Executing foreign key violations query with memory-efficient collection");
235
236        let violations_df = ctx.sql(&violations_sql).await.map_err(|e| {
237            TermError::constraint_evaluation(
238                "foreign_key",
239                format!("Failed to execute violations query: {e}"),
240            )
241        })?;
242
243        let batches = violations_df.collect().await.map_err(|e| {
244            TermError::constraint_evaluation(
245                "foreign_key",
246                format!("Failed to collect violation examples: {e}"),
247            )
248        })?;
249
250        // Pre-allocate with known maximum size to avoid reallocations
251        let mut violation_examples = Vec::with_capacity(self.max_violations_reported);
252
253        // Process batches efficiently with early termination
254        for batch in batches {
255            for i in 0..batch.num_rows() {
256                if violation_examples.len() >= self.max_violations_reported {
257                    debug!(
258                        "Reached max violations limit ({}), stopping collection",
259                        self.max_violations_reported
260                    );
261                    return Ok(violation_examples);
262                }
263
264                // Handle different data types efficiently
265                if let Some(string_array) = batch.column(0).as_any().downcast_ref::<StringArray>() {
266                    if !string_array.is_null(i) {
267                        violation_examples.push(string_array.value(i).to_string());
268                    }
269                } else if let Some(int64_array) =
270                    batch.column(0).as_any().downcast_ref::<Int64Array>()
271                {
272                    if !int64_array.is_null(i) {
273                        violation_examples.push(int64_array.value(i).to_string());
274                    }
275                } else if let Some(float64_array) = batch
276                    .column(0)
277                    .as_any()
278                    .downcast_ref::<arrow::array::Float64Array>()
279                {
280                    if !float64_array.is_null(i) {
281                        violation_examples.push(float64_array.value(i).to_string());
282                    }
283                } else if let Some(int32_array) = batch
284                    .column(0)
285                    .as_any()
286                    .downcast_ref::<arrow::array::Int32Array>()
287                {
288                    if !int32_array.is_null(i) {
289                        violation_examples.push(int32_array.value(i).to_string());
290                    }
291                }
292                // Add more data types as needed for broader compatibility
293            }
294        }
295
296        debug!(
297            "Collected {} foreign key violation examples",
298            violation_examples.len()
299        );
300        Ok(violation_examples)
301    }
302}
303
304#[async_trait]
305impl Constraint for ForeignKeyConstraint {
306    #[instrument(skip(self, ctx), fields(constraint = "foreign_key"))]
307    async fn evaluate(&self, ctx: &SessionContext) -> Result<ConstraintResult> {
308        debug!(
309            "Evaluating foreign key constraint: {} -> {}",
310            self.child_column, self.parent_column
311        );
312
313        // Parse qualified column names
314        let (child_table, child_col) = self.parse_qualified_column(&self.child_column)?;
315        let (parent_table, parent_col) = self.parse_qualified_column(&self.parent_column)?;
316
317        // Generate and execute validation query
318        let sql =
319            self.generate_left_join_query(&child_table, &child_col, &parent_table, &parent_col)?;
320        let df = ctx.sql(&sql).await.map_err(|e| {
321            TermError::constraint_evaluation(
322                "foreign_key",
323                format!("Foreign key validation query failed: {e}"),
324            )
325        })?;
326
327        let batches = df.collect().await.map_err(|e| {
328            TermError::constraint_evaluation(
329                "foreign_key",
330                format!("Failed to collect foreign key results: {e}"),
331            )
332        })?;
333
334        if batches.is_empty() || batches[0].num_rows() == 0 {
335            return Ok(ConstraintResult::success());
336        }
337
338        // Extract violation counts
339        let batch = &batches[0];
340        let total_violations = batch
341            .column(0)
342            .as_any()
343            .downcast_ref::<Int64Array>()
344            .ok_or_else(|| {
345                TermError::constraint_evaluation(
346                    "foreign_key",
347                    "Invalid total violations column type",
348                )
349            })?
350            .value(0);
351
352        let unique_violations = batch
353            .column(1)
354            .as_any()
355            .downcast_ref::<Int64Array>()
356            .ok_or_else(|| {
357                TermError::constraint_evaluation(
358                    "foreign_key",
359                    "Invalid unique violations column type",
360                )
361            })?
362            .value(0);
363
364        if total_violations == 0 {
365            debug!("Foreign key constraint passed: no violations found");
366            return Ok(ConstraintResult::success());
367        }
368
369        // Collect violation examples using memory-efficient approach
370        let violation_examples = self
371            .collect_violation_examples_efficiently(
372                ctx,
373                &child_table,
374                &child_col,
375                &parent_table,
376                &parent_col,
377            )
378            .await?;
379
380        // Format error message
381        let message = if violation_examples.is_empty() {
382            format!(
383                "Foreign key constraint violation: {total_violations} values in '{}' do not exist in '{}' (total: {total_violations}, unique: {unique_violations})",
384                self.child_column, self.parent_column
385            )
386        } else {
387            let examples_str = if violation_examples.len() <= 5 {
388                violation_examples.join(", ")
389            } else {
390                format!(
391                    "{}, ... ({} more)",
392                    violation_examples[..5].join(", "),
393                    violation_examples.len() - 5
394                )
395            };
396
397            format!(
398                "Foreign key constraint violation: {total_violations} values in '{}' do not exist in '{}' (total: {total_violations}, unique: {unique_violations}). Examples: [{examples_str}]",
399                self.child_column, self.parent_column
400            )
401        };
402
403        warn!("{}", message);
404
405        Ok(ConstraintResult {
406            status: ConstraintStatus::Failure,
407            metric: Some(total_violations as f64),
408            message: Some(message),
409        })
410    }
411
412    fn name(&self) -> &str {
413        "foreign_key"
414    }
415}
416
417#[cfg(test)]
418mod tests {
419    use super::*;
420    use crate::test_utils::create_test_context;
421
422    #[tokio::test]
423    async fn test_foreign_key_constraint_success() -> Result<()> {
424        let ctx = create_test_context().await?;
425
426        // Create test tables with valid foreign keys only
427        ctx.sql("CREATE TABLE customers_success (id BIGINT, name STRING)")
428            .await?
429            .collect()
430            .await?;
431        ctx.sql("INSERT INTO customers_success VALUES (1, 'Alice'), (2, 'Bob')")
432            .await?
433            .collect()
434            .await?;
435        ctx.sql("CREATE TABLE orders_success (id BIGINT, customer_id BIGINT, amount DOUBLE)")
436            .await?
437            .collect()
438            .await?;
439        ctx.sql("INSERT INTO orders_success VALUES (1, 1, 100.0), (2, 2, 200.0)")
440            .await?
441            .collect()
442            .await?;
443
444        let constraint =
445            ForeignKeyConstraint::new("orders_success.customer_id", "customers_success.id");
446        let result = constraint.evaluate(&ctx).await?;
447
448        assert_eq!(result.status, ConstraintStatus::Success);
449        assert!(result.message.is_none());
450
451        Ok(())
452    }
453
454    #[tokio::test]
455    async fn test_foreign_key_constraint_violation() -> Result<()> {
456        let ctx = create_test_context().await?;
457
458        // Create test tables with foreign key violations
459        ctx.sql("CREATE TABLE customers_violation (id BIGINT, name STRING)")
460            .await?
461            .collect()
462            .await?;
463        ctx.sql("INSERT INTO customers_violation VALUES (1, 'Alice'), (2, 'Bob'), (3, 'Charlie')")
464            .await?
465            .collect()
466            .await?;
467        ctx.sql("CREATE TABLE orders_violation (id BIGINT, customer_id BIGINT, amount DOUBLE)")
468            .await?
469            .collect()
470            .await?;
471        ctx.sql("INSERT INTO orders_violation VALUES (1, 1, 100.0), (2, 2, 200.0), (3, 999, 300.0), (4, 998, 400.0)")
472            .await?
473            .collect()
474            .await?;
475
476        let constraint =
477            ForeignKeyConstraint::new("orders_violation.customer_id", "customers_violation.id");
478        let result = constraint.evaluate(&ctx).await?;
479
480        assert_eq!(result.status, ConstraintStatus::Failure);
481        assert!(result.message.is_some());
482        assert_eq!(result.metric, Some(2.0)); // 2 violations
483
484        let message = result.message.unwrap();
485        assert!(message.contains("Foreign key constraint violation"));
486        assert!(message.contains("2 values"));
487        assert!(message.contains("orders_violation.customer_id"));
488        assert!(message.contains("customers_violation.id"));
489
490        Ok(())
491    }
492
493    #[tokio::test]
494    async fn test_foreign_key_with_nulls_disallowed() -> Result<()> {
495        let ctx = create_test_context().await?;
496
497        ctx.sql("CREATE TABLE customers_nulls_disallowed (id BIGINT, name STRING)")
498            .await?
499            .collect()
500            .await?;
501        ctx.sql("INSERT INTO customers_nulls_disallowed VALUES (1, 'Alice')")
502            .await?
503            .collect()
504            .await?;
505        ctx.sql(
506            "CREATE TABLE orders_nulls_disallowed (id BIGINT, customer_id BIGINT, amount DOUBLE)",
507        )
508        .await?
509        .collect()
510        .await?;
511        ctx.sql("INSERT INTO orders_nulls_disallowed VALUES (1, 1, 100.0), (2, NULL, 200.0)")
512            .await?
513            .collect()
514            .await?;
515
516        let constraint = ForeignKeyConstraint::new(
517            "orders_nulls_disallowed.customer_id",
518            "customers_nulls_disallowed.id",
519        )
520        .allow_nulls(false);
521        let result = constraint.evaluate(&ctx).await?;
522
523        // Should fail because null is not allowed
524        assert_eq!(result.status, ConstraintStatus::Failure);
525
526        Ok(())
527    }
528
529    #[tokio::test]
530    async fn test_foreign_key_with_nulls_allowed() -> Result<()> {
531        let ctx = create_test_context().await?;
532
533        ctx.sql("CREATE TABLE customers_nulls_allowed (id BIGINT, name STRING)")
534            .await?
535            .collect()
536            .await?;
537        ctx.sql("INSERT INTO customers_nulls_allowed VALUES (1, 'Alice')")
538            .await?
539            .collect()
540            .await?;
541        ctx.sql("CREATE TABLE orders_nulls_allowed (id BIGINT, customer_id BIGINT, amount DOUBLE)")
542            .await?
543            .collect()
544            .await?;
545        ctx.sql("INSERT INTO orders_nulls_allowed VALUES (1, 1, 100.0), (2, NULL, 200.0)")
546            .await?
547            .collect()
548            .await?;
549
550        let constraint = ForeignKeyConstraint::new(
551            "orders_nulls_allowed.customer_id",
552            "customers_nulls_allowed.id",
553        )
554        .allow_nulls(true);
555        let result = constraint.evaluate(&ctx).await?;
556
557        // Should succeed because null is allowed
558        assert_eq!(result.status, ConstraintStatus::Success);
559
560        Ok(())
561    }
562
563    #[test]
564    fn test_parse_qualified_column() {
565        let constraint = ForeignKeyConstraint::new("orders.customer_id", "customers.id");
566
567        let (table, column) = constraint
568            .parse_qualified_column("orders.customer_id")
569            .unwrap();
570        assert_eq!(table, "orders");
571        assert_eq!(column, "customer_id");
572
573        // Test invalid format
574        assert!(constraint.parse_qualified_column("invalid_column").is_err());
575        assert!(constraint.parse_qualified_column("too.many.parts").is_err());
576    }
577
578    #[test]
579    fn test_constraint_configuration() {
580        let constraint = ForeignKeyConstraint::new("orders.customer_id", "customers.id")
581            .allow_nulls(true)
582            .use_left_join(false)
583            .max_violations_reported(50);
584
585        assert_eq!(constraint.child_column(), "orders.customer_id");
586        assert_eq!(constraint.parent_column(), "customers.id");
587        assert!(constraint.allow_nulls);
588        assert!(!constraint.use_left_join);
589        assert_eq!(constraint.max_violations_reported, 50);
590    }
591
592    #[test]
593    fn test_constraint_name() {
594        let constraint = ForeignKeyConstraint::new("orders.customer_id", "customers.id");
595        assert_eq!(constraint.name(), "foreign_key");
596    }
597
598    #[test]
599    fn test_sql_generation() -> Result<()> {
600        let constraint = ForeignKeyConstraint::new("orders.customer_id", "customers.id");
601        let sql =
602            constraint.generate_left_join_query("orders", "customer_id", "customers", "id")?;
603
604        assert!(sql.contains("LEFT JOIN"));
605        assert!(sql.contains("orders.customer_id = customers.id"));
606        assert!(sql.contains("customers.id IS NULL"));
607        assert!(sql.contains("COUNT(*) as total_violations"));
608
609        Ok(())
610    }
611
612    #[test]
613    fn test_sql_generation_with_nulls_allowed() -> Result<()> {
614        let constraint =
615            ForeignKeyConstraint::new("orders.customer_id", "customers.id").allow_nulls(true);
616        let sql =
617            constraint.generate_left_join_query("orders", "customer_id", "customers", "id")?;
618
619        assert!(sql.contains("AND orders.customer_id IS NOT NULL"));
620
621        Ok(())
622    }
623}