1use crate::core::{Constraint, ConstraintResult, ConstraintStatus};
35use crate::error::{Result, TermError};
36use crate::security::SqlSecurity;
37use arrow::array::{Array, Int64Array, StringArray};
38use async_trait::async_trait;
39use datafusion::prelude::*;
40use serde::{Deserialize, Serialize};
41use tracing::{debug, instrument, warn};
42
/// Referential-integrity check between a child (referencing) column and a
/// parent (referenced) column.
///
/// Both column names are kept as the strings supplied to the constructor and
/// are expected to be in `table.column` form.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ForeignKeyConstraint {
    /// Qualified name (`table.column`) of the referencing column.
    child_column: String,
    /// Qualified name (`table.column`) of the referenced column.
    parent_column: String,
    /// When `true`, the validation query adds an `IS NOT NULL` filter on the
    /// child column, so NULL child values are not counted as violations.
    allow_nulls: bool,
    /// Join-strategy preference set through the builder.
    /// NOTE(review): no query builder in this file reads this flag; the
    /// generated SQL always uses a LEFT JOIN.
    use_left_join: bool,
    /// Upper bound on the number of violating values collected as examples;
    /// `0` disables example collection.
    max_violations_reported: usize,
}
67
68impl ForeignKeyConstraint {
69 pub fn new(child_column: impl Into<String>, parent_column: impl Into<String>) -> Self {
84 Self {
85 child_column: child_column.into(),
86 parent_column: parent_column.into(),
87 allow_nulls: false,
88 use_left_join: true,
89 max_violations_reported: 100,
90 }
91 }
92
93 pub fn allow_nulls(mut self, allow: bool) -> Self {
98 self.allow_nulls = allow;
99 self
100 }
101
102 pub fn use_left_join(mut self, use_left_join: bool) -> Self {
107 self.use_left_join = use_left_join;
108 self
109 }
110
111 pub fn max_violations_reported(mut self, max_violations: usize) -> Self {
115 self.max_violations_reported = max_violations;
116 self
117 }
118
119 pub fn child_column(&self) -> &str {
121 &self.child_column
122 }
123
124 pub fn parent_column(&self) -> &str {
126 &self.parent_column
127 }
128
129 fn parse_qualified_column(&self, qualified_column: &str) -> Result<(String, String)> {
131 let parts: Vec<&str> = qualified_column.split('.').collect();
132 if parts.len() != 2 {
133 return Err(TermError::constraint_evaluation(
134 "foreign_key",
135 format!(
136 "Foreign key column must be qualified (table.column): '{qualified_column}'"
137 ),
138 ));
139 }
140
141 let table = parts[0].to_string();
142 let column = parts[1].to_string();
143
144 SqlSecurity::validate_identifier(&table)?;
146 SqlSecurity::validate_identifier(&column)?;
147
148 Ok((table, column))
149 }
150
151 fn generate_left_join_query(
153 &self,
154 child_table: &str,
155 child_col: &str,
156 parent_table: &str,
157 parent_col: &str,
158 ) -> Result<String> {
159 let null_condition = if self.allow_nulls {
160 format!("AND {child_table}.{child_col} IS NOT NULL")
161 } else {
162 String::new()
163 };
164
165 let sql = format!(
166 "SELECT
167 COUNT(*) as total_violations,
168 COUNT(DISTINCT {child_table}.{child_col}) as unique_violations
169 FROM {child_table}
170 LEFT JOIN {parent_table} ON {child_table}.{child_col} = {parent_table}.{parent_col}
171 WHERE {parent_table}.{parent_col} IS NULL {null_condition}"
172 );
173
174 debug!("Generated foreign key validation query: {}", sql);
175 Ok(sql)
176 }
177
178 fn generate_violations_query(
180 &self,
181 child_table: &str,
182 child_col: &str,
183 parent_table: &str,
184 parent_col: &str,
185 ) -> Result<String> {
186 if self.max_violations_reported == 0 {
187 return Ok(String::new());
188 }
189
190 let null_condition = if self.allow_nulls {
191 format!("AND {child_table}.{child_col} IS NOT NULL")
192 } else {
193 String::new()
194 };
195
196 let limit = self.max_violations_reported;
197 let sql = format!(
198 "SELECT DISTINCT {child_table}.{child_col} as violating_value
199 FROM {child_table}
200 LEFT JOIN {parent_table} ON {child_table}.{child_col} = {parent_table}.{parent_col}
201 WHERE {parent_table}.{parent_col} IS NULL {null_condition}
202 LIMIT {limit}"
203 );
204
205 debug!("Generated violations query: {}", sql);
206 Ok(sql)
207 }
208
209 async fn collect_violation_examples_efficiently(
217 &self,
218 ctx: &SessionContext,
219 child_table: &str,
220 child_col: &str,
221 parent_table: &str,
222 parent_col: &str,
223 ) -> Result<Vec<String>> {
224 if self.max_violations_reported == 0 {
225 return Ok(Vec::new());
226 }
227
228 let violations_sql =
229 self.generate_violations_query(child_table, child_col, parent_table, parent_col)?;
230 if violations_sql.is_empty() {
231 return Ok(Vec::new());
232 }
233
234 debug!("Executing foreign key violations query with memory-efficient collection");
235
236 let violations_df = ctx.sql(&violations_sql).await.map_err(|e| {
237 TermError::constraint_evaluation(
238 "foreign_key",
239 format!("Failed to execute violations query: {e}"),
240 )
241 })?;
242
243 let batches = violations_df.collect().await.map_err(|e| {
244 TermError::constraint_evaluation(
245 "foreign_key",
246 format!("Failed to collect violation examples: {e}"),
247 )
248 })?;
249
250 let mut violation_examples = Vec::with_capacity(self.max_violations_reported);
252
253 for batch in batches {
255 for i in 0..batch.num_rows() {
256 if violation_examples.len() >= self.max_violations_reported {
257 debug!(
258 "Reached max violations limit ({}), stopping collection",
259 self.max_violations_reported
260 );
261 return Ok(violation_examples);
262 }
263
264 if let Some(string_array) = batch.column(0).as_any().downcast_ref::<StringArray>() {
266 if !string_array.is_null(i) {
267 violation_examples.push(string_array.value(i).to_string());
268 }
269 } else if let Some(int64_array) =
270 batch.column(0).as_any().downcast_ref::<Int64Array>()
271 {
272 if !int64_array.is_null(i) {
273 violation_examples.push(int64_array.value(i).to_string());
274 }
275 } else if let Some(float64_array) = batch
276 .column(0)
277 .as_any()
278 .downcast_ref::<arrow::array::Float64Array>()
279 {
280 if !float64_array.is_null(i) {
281 violation_examples.push(float64_array.value(i).to_string());
282 }
283 } else if let Some(int32_array) = batch
284 .column(0)
285 .as_any()
286 .downcast_ref::<arrow::array::Int32Array>()
287 {
288 if !int32_array.is_null(i) {
289 violation_examples.push(int32_array.value(i).to_string());
290 }
291 }
292 }
294 }
295
296 debug!(
297 "Collected {} foreign key violation examples",
298 violation_examples.len()
299 );
300 Ok(violation_examples)
301 }
302}
303
304#[async_trait]
305impl Constraint for ForeignKeyConstraint {
306 #[instrument(skip(self, ctx), fields(constraint = "foreign_key"))]
307 async fn evaluate(&self, ctx: &SessionContext) -> Result<ConstraintResult> {
308 debug!(
309 "Evaluating foreign key constraint: {} -> {}",
310 self.child_column, self.parent_column
311 );
312
313 let (child_table, child_col) = self.parse_qualified_column(&self.child_column)?;
315 let (parent_table, parent_col) = self.parse_qualified_column(&self.parent_column)?;
316
317 let sql =
319 self.generate_left_join_query(&child_table, &child_col, &parent_table, &parent_col)?;
320 let df = ctx.sql(&sql).await.map_err(|e| {
321 TermError::constraint_evaluation(
322 "foreign_key",
323 format!("Foreign key validation query failed: {e}"),
324 )
325 })?;
326
327 let batches = df.collect().await.map_err(|e| {
328 TermError::constraint_evaluation(
329 "foreign_key",
330 format!("Failed to collect foreign key results: {e}"),
331 )
332 })?;
333
334 if batches.is_empty() || batches[0].num_rows() == 0 {
335 return Ok(ConstraintResult::success());
336 }
337
338 let batch = &batches[0];
340 let total_violations = batch
341 .column(0)
342 .as_any()
343 .downcast_ref::<Int64Array>()
344 .ok_or_else(|| {
345 TermError::constraint_evaluation(
346 "foreign_key",
347 "Invalid total violations column type",
348 )
349 })?
350 .value(0);
351
352 let unique_violations = batch
353 .column(1)
354 .as_any()
355 .downcast_ref::<Int64Array>()
356 .ok_or_else(|| {
357 TermError::constraint_evaluation(
358 "foreign_key",
359 "Invalid unique violations column type",
360 )
361 })?
362 .value(0);
363
364 if total_violations == 0 {
365 debug!("Foreign key constraint passed: no violations found");
366 return Ok(ConstraintResult::success());
367 }
368
369 let violation_examples = self
371 .collect_violation_examples_efficiently(
372 ctx,
373 &child_table,
374 &child_col,
375 &parent_table,
376 &parent_col,
377 )
378 .await?;
379
380 let message = if violation_examples.is_empty() {
382 format!(
383 "Foreign key constraint violation: {total_violations} values in '{}' do not exist in '{}' (total: {total_violations}, unique: {unique_violations})",
384 self.child_column, self.parent_column
385 )
386 } else {
387 let examples_str = if violation_examples.len() <= 5 {
388 violation_examples.join(", ")
389 } else {
390 format!(
391 "{}, ... ({} more)",
392 violation_examples[..5].join(", "),
393 violation_examples.len() - 5
394 )
395 };
396
397 format!(
398 "Foreign key constraint violation: {total_violations} values in '{}' do not exist in '{}' (total: {total_violations}, unique: {unique_violations}). Examples: [{examples_str}]",
399 self.child_column, self.parent_column
400 )
401 };
402
403 warn!("{}", message);
404
405 Ok(ConstraintResult {
406 status: ConstraintStatus::Failure,
407 metric: Some(total_violations as f64),
408 message: Some(message),
409 })
410 }
411
412 fn name(&self) -> &str {
413 "foreign_key"
414 }
415}
416
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_utils::create_test_context;

    /// Executes every statement in order against `ctx`, discarding the
    /// returned batches.
    async fn execute_all(ctx: &SessionContext, statements: &[&str]) -> Result<()> {
        for &statement in statements {
            ctx.sql(statement).await?.collect().await?;
        }
        Ok(())
    }

    /// Every child value has a matching parent row, so the check succeeds.
    #[tokio::test]
    async fn test_foreign_key_constraint_success() -> Result<()> {
        let ctx = create_test_context().await?;
        execute_all(
            &ctx,
            &[
                "CREATE TABLE customers_success (id BIGINT, name STRING)",
                "INSERT INTO customers_success VALUES (1, 'Alice'), (2, 'Bob')",
                "CREATE TABLE orders_success (id BIGINT, customer_id BIGINT, amount DOUBLE)",
                "INSERT INTO orders_success VALUES (1, 1, 100.0), (2, 2, 200.0)",
            ],
        )
        .await?;

        let outcome =
            ForeignKeyConstraint::new("orders_success.customer_id", "customers_success.id")
                .evaluate(&ctx)
                .await?;

        assert_eq!(outcome.status, ConstraintStatus::Success);
        assert!(outcome.message.is_none());

        Ok(())
    }

    /// Two child rows reference customers that do not exist.
    #[tokio::test]
    async fn test_foreign_key_constraint_violation() -> Result<()> {
        let ctx = create_test_context().await?;
        execute_all(
            &ctx,
            &[
                "CREATE TABLE customers_violation (id BIGINT, name STRING)",
                "INSERT INTO customers_violation VALUES (1, 'Alice'), (2, 'Bob'), (3, 'Charlie')",
                "CREATE TABLE orders_violation (id BIGINT, customer_id BIGINT, amount DOUBLE)",
                "INSERT INTO orders_violation VALUES (1, 1, 100.0), (2, 2, 200.0), (3, 999, 300.0), (4, 998, 400.0)",
            ],
        )
        .await?;

        let outcome =
            ForeignKeyConstraint::new("orders_violation.customer_id", "customers_violation.id")
                .evaluate(&ctx)
                .await?;

        assert_eq!(outcome.status, ConstraintStatus::Failure);
        assert!(outcome.message.is_some());
        assert_eq!(outcome.metric, Some(2.0));

        let text = outcome.message.unwrap();
        for fragment in [
            "Foreign key constraint violation",
            "2 values",
            "orders_violation.customer_id",
            "customers_violation.id",
        ] {
            assert!(text.contains(fragment));
        }

        Ok(())
    }

    /// A NULL child value counts as a violation when NULLs are disallowed.
    #[tokio::test]
    async fn test_foreign_key_with_nulls_disallowed() -> Result<()> {
        let ctx = create_test_context().await?;
        execute_all(
            &ctx,
            &[
                "CREATE TABLE customers_nulls_disallowed (id BIGINT, name STRING)",
                "INSERT INTO customers_nulls_disallowed VALUES (1, 'Alice')",
                "CREATE TABLE orders_nulls_disallowed (id BIGINT, customer_id BIGINT, amount DOUBLE)",
                "INSERT INTO orders_nulls_disallowed VALUES (1, 1, 100.0), (2, NULL, 200.0)",
            ],
        )
        .await?;

        let strict = ForeignKeyConstraint::new(
            "orders_nulls_disallowed.customer_id",
            "customers_nulls_disallowed.id",
        )
        .allow_nulls(false);
        let outcome = strict.evaluate(&ctx).await?;

        assert_eq!(outcome.status, ConstraintStatus::Failure);

        Ok(())
    }

    /// A NULL child value is ignored when NULLs are allowed.
    #[tokio::test]
    async fn test_foreign_key_with_nulls_allowed() -> Result<()> {
        let ctx = create_test_context().await?;
        execute_all(
            &ctx,
            &[
                "CREATE TABLE customers_nulls_allowed (id BIGINT, name STRING)",
                "INSERT INTO customers_nulls_allowed VALUES (1, 'Alice')",
                "CREATE TABLE orders_nulls_allowed (id BIGINT, customer_id BIGINT, amount DOUBLE)",
                "INSERT INTO orders_nulls_allowed VALUES (1, 1, 100.0), (2, NULL, 200.0)",
            ],
        )
        .await?;

        let lenient = ForeignKeyConstraint::new(
            "orders_nulls_allowed.customer_id",
            "customers_nulls_allowed.id",
        )
        .allow_nulls(true);
        let outcome = lenient.evaluate(&ctx).await?;

        assert_eq!(outcome.status, ConstraintStatus::Success);

        Ok(())
    }

    /// Accepts exactly one `.` separator and rejects anything else.
    #[test]
    fn test_parse_qualified_column() {
        let fk = ForeignKeyConstraint::new("orders.customer_id", "customers.id");

        let parsed = fk.parse_qualified_column("orders.customer_id").unwrap();
        assert_eq!(parsed.0, "orders");
        assert_eq!(parsed.1, "customer_id");

        for malformed in ["invalid_column", "too.many.parts"] {
            assert!(fk.parse_qualified_column(malformed).is_err());
        }
    }

    /// Builder methods store the values they are given.
    #[test]
    fn test_constraint_configuration() {
        let base = ForeignKeyConstraint::new("orders.customer_id", "customers.id");
        let configured = base
            .allow_nulls(true)
            .use_left_join(false)
            .max_violations_reported(50);

        assert_eq!(configured.child_column(), "orders.customer_id");
        assert_eq!(configured.parent_column(), "customers.id");
        assert!(configured.allow_nulls);
        assert!(!configured.use_left_join);
        assert_eq!(configured.max_violations_reported, 50);
    }

    /// The constraint reports its fixed identifier.
    #[test]
    fn test_constraint_name() {
        let fk = ForeignKeyConstraint::new("orders.customer_id", "customers.id");
        assert_eq!(fk.name(), "foreign_key");
    }

    /// The counting query contains the join, the anti-join filter and the count.
    #[test]
    fn test_sql_generation() -> Result<()> {
        let fk = ForeignKeyConstraint::new("orders.customer_id", "customers.id");
        let query = fk.generate_left_join_query("orders", "customer_id", "customers", "id")?;

        for fragment in [
            "LEFT JOIN",
            "orders.customer_id = customers.id",
            "customers.id IS NULL",
            "COUNT(*) as total_violations",
        ] {
            assert!(query.contains(fragment));
        }

        Ok(())
    }

    /// Allowing NULLs adds the `IS NOT NULL` filter on the child column.
    #[test]
    fn test_sql_generation_with_nulls_allowed() -> Result<()> {
        let fk = ForeignKeyConstraint::new("orders.customer_id", "customers.id").allow_nulls(true);
        let query = fk.generate_left_join_query("orders", "customer_id", "customers", "id")?;

        assert!(query.contains("AND orders.customer_id IS NOT NULL"));

        Ok(())
    }
}