1use crate::core::{Constraint, ConstraintResult, ConstraintStatus};
37use crate::error::{Result, TermError};
38use crate::security::SqlSecurity;
39use arrow::array::{Array, Float64Array, StringArray};
40use async_trait::async_trait;
41use datafusion::prelude::*;
42use serde::{Deserialize, Serialize};
43use tracing::{debug, instrument, warn};
44
/// Constraint that compares the sum of a column in one table with the sum of
/// a column in another table, either over the whole tables or per group.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CrossTableSumConstraint {
    /// Left-hand column, written as `table.column`.
    left_column: String,
    /// Right-hand column, written as `table.column`.
    right_column: String,
    /// Unqualified column names both tables are grouped by; empty means the
    /// overall totals are compared.
    group_by_columns: Vec<String>,
    /// Largest absolute difference between the two sums that is still accepted.
    tolerance: f64,
    /// Upper bound on the number of violation examples gathered for the
    /// failure message; `0` disables example collection.
    max_violations_reported: usize,
}
71
72impl CrossTableSumConstraint {
73 pub fn new(left_column: impl Into<String>, right_column: impl Into<String>) -> Self {
88 Self {
89 left_column: left_column.into(),
90 right_column: right_column.into(),
91 group_by_columns: Vec::new(),
92 tolerance: 0.0,
93 max_violations_reported: 100,
94 }
95 }
96
97 pub fn group_by(mut self, columns: Vec<impl Into<String>>) -> Self {
111 self.group_by_columns = columns.into_iter().map(Into::into).collect();
112 self
113 }
114
115 pub fn tolerance(mut self, tolerance: f64) -> Self {
130 self.tolerance = tolerance.abs(); self
132 }
133
134 pub fn max_violations_reported(mut self, max_violations: usize) -> Self {
138 self.max_violations_reported = max_violations;
139 self
140 }
141
142 pub fn left_column(&self) -> &str {
144 &self.left_column
145 }
146
147 pub fn right_column(&self) -> &str {
149 &self.right_column
150 }
151
152 pub fn group_by_columns(&self) -> &[String] {
154 &self.group_by_columns
155 }
156
157 fn parse_qualified_column(&self, qualified_column: &str) -> Result<(String, String)> {
159 let parts: Vec<&str> = qualified_column.split('.').collect();
160 if parts.len() != 2 {
161 return Err(TermError::constraint_evaluation(
162 "cross_table_sum",
163 format!("Column must be qualified (table.column): '{qualified_column}'"),
164 ));
165 }
166
167 let table = parts[0].to_string();
168 let column = parts[1].to_string();
169
170 SqlSecurity::validate_identifier(&table)?;
172 SqlSecurity::validate_identifier(&column)?;
173
174 Ok((table, column))
175 }
176
177 fn validate_group_by_columns(&self) -> Result<()> {
179 for column in &self.group_by_columns {
180 SqlSecurity::validate_identifier(column)?;
181 }
182 Ok(())
183 }
184
185 fn generate_validation_query(
192 &self,
193 left_table: &str,
194 left_col: &str,
195 right_table: &str,
196 right_col: &str,
197 ) -> Result<String> {
198 if self.group_by_columns.is_empty() {
199 let tolerance = self.tolerance;
201 let sql = format!(
202 "SELECT
203 1 as total_groups,
204 CASE WHEN ABS(left_total - right_total) > {tolerance}
205 THEN 1 ELSE 0 END as violating_groups,
206 left_total as total_left_sum,
207 right_total as total_right_sum,
208 ABS(left_total - right_total) as max_difference
209 FROM (
210 SELECT
211 COALESCE((SELECT SUM({left_col}) FROM {left_table}), 0.0) as left_total,
212 COALESCE((SELECT SUM({right_col}) FROM {right_table}), 0.0) as right_total
213 ) totals"
214 );
215 debug!("Generated optimized non-grouped cross-table sum query: {sql}");
216 Ok(sql)
217 } else {
218 let group_columns = self
220 .group_by_columns
221 .iter()
222 .map(|col| col.to_string())
223 .collect::<Vec<_>>();
224
225 let left_group_select = group_columns
226 .iter()
227 .map(|col| format!("{left_table}.{col}"))
228 .collect::<Vec<_>>()
229 .join(", ");
230
231 let right_group_select = group_columns
232 .iter()
233 .map(|col| format!("{right_table}.{col}"))
234 .collect::<Vec<_>>()
235 .join(", ");
236
237 let _group_by_clause = group_columns
238 .iter()
239 .map(|col| col.to_string())
240 .collect::<Vec<_>>()
241 .join(", ");
242
243 let tolerance = self.tolerance;
245 let join_condition = group_columns
246 .iter()
247 .map(|col| format!("l.{col} = r.{col}"))
248 .collect::<Vec<_>>()
249 .join(" AND ");
250 let sql = format!(
251 "WITH left_sums AS (
252 SELECT {left_group_select},
253 COALESCE(SUM({left_table}.{left_col}), 0.0) as left_sum
254 FROM {left_table}
255 GROUP BY {left_group_select}
256 ),
257 right_sums AS (
258 SELECT {right_group_select},
259 COALESCE(SUM({right_table}.{right_col}), 0.0) as right_sum
260 FROM {right_table}
261 GROUP BY {right_group_select}
262 ),
263 combined_data AS (
264 SELECT
265 COALESCE(l.left_sum, 0.0) as total_left_sum,
266 COALESCE(r.right_sum, 0.0) as total_right_sum,
267 ABS(COALESCE(l.left_sum, 0.0) - COALESCE(r.right_sum, 0.0)) as difference,
268 CASE WHEN ABS(COALESCE(l.left_sum, 0.0) - COALESCE(r.right_sum, 0.0)) > {tolerance}
269 THEN 1 ELSE 0 END as is_violation
270 FROM left_sums l
271 FULL OUTER JOIN right_sums r ON {join_condition}
272 )
273 SELECT
274 COUNT(*) as total_groups,
275 SUM(is_violation) as violating_groups,
276 SUM(total_left_sum) as total_left_sum,
277 SUM(total_right_sum) as total_right_sum,
278 MAX(difference) as max_difference
279 FROM combined_data"
280 );
281 debug!("Generated optimized grouped cross-table sum query: {sql}");
282 Ok(sql)
283 }
284 }
285
286 fn generate_violations_query(
288 &self,
289 left_table: &str,
290 left_col: &str,
291 right_table: &str,
292 right_col: &str,
293 ) -> Result<String> {
294 if self.max_violations_reported == 0 {
295 return Ok(String::new());
296 }
297
298 if self.group_by_columns.is_empty() {
299 let tolerance = self.tolerance;
301 let limit = self.max_violations_reported;
302 let sql = format!(
303 "SELECT
304 'ALL' as group_key,
305 left_total as left_sum,
306 right_total as right_sum,
307 ABS(left_total - right_total) as difference
308 FROM (
309 SELECT
310 COALESCE((SELECT SUM({left_col}) FROM {left_table}), 0.0) as left_total,
311 COALESCE((SELECT SUM({right_col}) FROM {right_table}), 0.0) as right_total
312 ) totals
313 WHERE ABS(left_total - right_total) > {tolerance}
314 LIMIT {limit}"
315 );
316 debug!("Generated optimized non-grouped violations query: {sql}");
317 Ok(sql)
318 } else {
319 let group_columns = self
321 .group_by_columns
322 .iter()
323 .map(|col| col.to_string())
324 .collect::<Vec<_>>();
325
326 let left_group_select = group_columns
327 .iter()
328 .map(|col| format!("{left_table}.{col}"))
329 .collect::<Vec<_>>()
330 .join(", ");
331
332 let right_group_select = group_columns
333 .iter()
334 .map(|col| format!("{right_table}.{col}"))
335 .collect::<Vec<_>>()
336 .join(", ");
337
338 let group_key_concat = if group_columns.len() == 1 {
339 format!(
340 "CAST(COALESCE(l.{}, r.{}) AS STRING)",
341 group_columns[0], group_columns[0]
342 )
343 } else {
344 format!(
345 "CONCAT({})",
346 group_columns
347 .iter()
348 .map(|col| format!("CAST(COALESCE(l.{col}, r.{col}) AS STRING)"))
349 .collect::<Vec<_>>()
350 .join(", '|', ")
351 )
352 };
353
354 let _group_by_clause = group_columns.join(", ");
355
356 let tolerance = self.tolerance;
357 let limit = self.max_violations_reported;
358 let join_condition = group_columns
359 .iter()
360 .map(|col| format!("l.{col} = r.{col}"))
361 .collect::<Vec<_>>()
362 .join(" AND ");
363 let sql = format!(
364 "WITH left_sums AS (
365 SELECT {left_group_select},
366 COALESCE(SUM({left_table}.{left_col}), 0.0) as left_sum
367 FROM {left_table}
368 GROUP BY {left_group_select}
369 ),
370 right_sums AS (
371 SELECT {right_group_select},
372 COALESCE(SUM({right_table}.{right_col}), 0.0) as right_sum
373 FROM {right_table}
374 GROUP BY {right_group_select}
375 )
376 SELECT
377 {group_key_concat} as group_key,
378 COALESCE(l.left_sum, 0.0) as left_sum,
379 COALESCE(r.right_sum, 0.0) as right_sum,
380 ABS(COALESCE(l.left_sum, 0.0) - COALESCE(r.right_sum, 0.0)) as difference
381 FROM left_sums l
382 FULL OUTER JOIN right_sums r ON {join_condition}
383 WHERE ABS(COALESCE(l.left_sum, 0.0) - COALESCE(r.right_sum, 0.0)) > {tolerance}
384 ORDER BY ABS(COALESCE(l.left_sum, 0.0) - COALESCE(r.right_sum, 0.0)) DESC
385 LIMIT {limit}"
386 );
387 debug!("Generated optimized grouped violations query: {sql}");
388 Ok(sql)
389 }
390 }
391
392 async fn collect_violation_examples_simple(
399 &self,
400 ctx: &SessionContext,
401 left_table: &str,
402 left_col: &str,
403 right_table: &str,
404 right_col: &str,
405 ) -> Result<Vec<String>> {
406 if !self.group_by_columns.is_empty() {
411 debug!("Skipping violation example collection for grouped constraint due to DataFusion limitations");
412 return Ok(Vec::new());
413 }
414
415 let violations_sql =
416 self.generate_violations_query(left_table, left_col, right_table, right_col)?;
417 if violations_sql.is_empty() {
418 return Ok(Vec::new());
419 }
420
421 debug!("Executing simple violations query");
422
423 let violations_df = ctx.sql(&violations_sql).await.map_err(|e| {
424 TermError::constraint_evaluation(
425 "cross_table_sum",
426 format!("Failed to execute violations query: {e}"),
427 )
428 })?;
429
430 let batches = violations_df.collect().await.map_err(|e| {
431 TermError::constraint_evaluation(
432 "cross_table_sum",
433 format!("Failed to collect violation examples: {e}"),
434 )
435 })?;
436
437 let mut violation_examples = Vec::with_capacity(self.max_violations_reported);
438
439 for batch in batches {
440 for i in 0..batch.num_rows() {
441 if violation_examples.len() >= self.max_violations_reported {
442 break;
443 }
444
445 if let (Some(group_key), Some(left_sum), Some(right_sum), Some(diff)) = (
446 batch.column(0).as_any().downcast_ref::<StringArray>(),
447 batch.column(1).as_any().downcast_ref::<Float64Array>(),
448 batch.column(2).as_any().downcast_ref::<Float64Array>(),
449 batch.column(3).as_any().downcast_ref::<Float64Array>(),
450 ) {
451 if !group_key.is_null(i) {
452 violation_examples.push(format!(
453 "Group '{}': {} = {:.4}, {} = {:.4} (diff: {:.4})",
454 group_key.value(i),
455 self.left_column,
456 left_sum.value(i),
457 self.right_column,
458 right_sum.value(i),
459 diff.value(i)
460 ));
461 }
462 }
463 }
464 }
465
466 debug!("Collected {} violation examples", violation_examples.len());
467 Ok(violation_examples)
468 }
469}
470
471#[async_trait]
472impl Constraint for CrossTableSumConstraint {
473 #[instrument(skip(self, ctx), fields(constraint = "cross_table_sum"))]
474 async fn evaluate(&self, ctx: &SessionContext) -> Result<ConstraintResult> {
475 debug!(
476 "Evaluating cross-table sum constraint: {} vs {}",
477 self.left_column, self.right_column
478 );
479
480 let (left_table, left_col) = self.parse_qualified_column(&self.left_column)?;
482 let (right_table, right_col) = self.parse_qualified_column(&self.right_column)?;
483
484 self.validate_group_by_columns()?;
486
487 let sql =
489 self.generate_validation_query(&left_table, &left_col, &right_table, &right_col)?;
490 let df = ctx.sql(&sql).await.map_err(|e| {
491 TermError::constraint_evaluation(
492 "cross_table_sum",
493 format!("Cross-table sum validation query failed: {e}"),
494 )
495 })?;
496
497 let batches = df.collect().await.map_err(|e| {
498 TermError::constraint_evaluation(
499 "cross_table_sum",
500 format!("Failed to collect cross-table sum results: {e}"),
501 )
502 })?;
503
504 if batches.is_empty() || batches[0].num_rows() == 0 {
505 return Ok(ConstraintResult::skipped(
506 "No data found for cross-table sum comparison",
507 ));
508 }
509
510 let batch = &batches[0];
512 let total_groups = batch
513 .column(0)
514 .as_any()
515 .downcast_ref::<arrow::array::Int64Array>()
516 .ok_or_else(|| {
517 TermError::constraint_evaluation(
518 "cross_table_sum",
519 "Invalid total_groups column type",
520 )
521 })?
522 .value(0);
523
524 let violating_groups = batch
525 .column(1)
526 .as_any()
527 .downcast_ref::<arrow::array::Int64Array>()
528 .ok_or_else(|| {
529 TermError::constraint_evaluation(
530 "cross_table_sum",
531 "Invalid violating_groups column type",
532 )
533 })?
534 .value(0);
535
536 let total_left_sum = batch
537 .column(2)
538 .as_any()
539 .downcast_ref::<Float64Array>()
540 .ok_or_else(|| {
541 TermError::constraint_evaluation(
542 "cross_table_sum",
543 "Invalid total_left_sum column type",
544 )
545 })?
546 .value(0);
547
548 let total_right_sum = batch
549 .column(3)
550 .as_any()
551 .downcast_ref::<Float64Array>()
552 .ok_or_else(|| {
553 TermError::constraint_evaluation(
554 "cross_table_sum",
555 "Invalid total_right_sum column type",
556 )
557 })?
558 .value(0);
559
560 let max_difference = batch
561 .column(4)
562 .as_any()
563 .downcast_ref::<Float64Array>()
564 .ok_or_else(|| {
565 TermError::constraint_evaluation(
566 "cross_table_sum",
567 "Invalid max_difference column type",
568 )
569 })?
570 .value(0);
571
572 if violating_groups == 0 {
573 debug!("Cross-table sum constraint passed: all groups match within tolerance");
574 return Ok(ConstraintResult::success_with_metric(max_difference));
575 }
576
577 let mut violation_examples = Vec::new();
579 if self.max_violations_reported > 0 {
580 violation_examples = self
581 .collect_violation_examples_simple(
582 ctx,
583 &left_table,
584 &left_col,
585 &right_table,
586 &right_col,
587 )
588 .await?;
589 }
590
591 let tolerance_text = if self.tolerance > 0.0 {
593 format!(" (tolerance: {:.4})", self.tolerance)
594 } else {
595 " (exact match required)".to_string()
596 };
597
598 let grouping_text = if self.group_by_columns.is_empty() {
599 "overall totals".to_string()
600 } else {
601 format!("groups by [{}]", self.group_by_columns.join(", "))
602 };
603
604 let message = if violation_examples.is_empty() {
605 format!(
606 "Cross-table sum mismatch: {violating_groups}/{total_groups} {grouping_text} failed validation{tolerance_text}, total sums: {total_left_sum} vs {total_right_sum} (max diff: {max_difference:.4})"
607 )
608 } else {
609 let examples_str = if violation_examples.len() <= 3 {
610 violation_examples.join("; ")
611 } else {
612 format!(
613 "{}; ... ({} more)",
614 violation_examples[..3].join("; "),
615 violation_examples.len() - 3
616 )
617 };
618
619 format!(
620 "Cross-table sum mismatch: {violating_groups}/{total_groups} {grouping_text} failed validation{tolerance_text}. Examples: [{examples_str}]"
621 )
622 };
623
624 warn!("{}", message);
625
626 Ok(ConstraintResult {
627 status: ConstraintStatus::Failure,
628 metric: Some(max_difference),
629 message: Some(message),
630 })
631 }
632
633 fn name(&self) -> &str {
634 "cross_table_sum"
635 }
636}
637
638#[cfg(test)]
639mod tests {
640 use super::*;
641 use crate::test_utils::create_test_context;
642
643 async fn create_test_tables(ctx: &SessionContext, table_suffix: &str) -> Result<()> {
644 let orders_table = format!("orders_{table_suffix}");
645 let payments_table = format!("payments_{table_suffix}");
646
647 ctx.sql(&format!(
649 "CREATE TABLE {orders_table} (id BIGINT, customer_id BIGINT, total DOUBLE)"
650 ))
651 .await?
652 .collect()
653 .await?;
654 ctx.sql(&format!(
655 "INSERT INTO {orders_table} VALUES (1, 1, 100.0), (2, 1, 200.0), (3, 2, 150.0), (4, 2, 300.0)"
656 ))
657 .await?
658 .collect()
659 .await?;
660
661 ctx.sql(&format!(
663 "CREATE TABLE {payments_table} (id BIGINT, customer_id BIGINT, amount DOUBLE)"
664 ))
665 .await?
666 .collect()
667 .await?;
668 ctx.sql(&format!(
669 "INSERT INTO {payments_table} VALUES (1, 1, 300.0), (2, 2, 450.0)"
670 ))
671 .await?
672 .collect()
673 .await?;
674
675 Ok(())
676 }
677
678 #[tokio::test]
679 async fn test_cross_table_sum_success() -> Result<()> {
680 let ctx = create_test_context().await?;
681 create_test_tables(&ctx, "success").await?;
682
683 let constraint =
684 CrossTableSumConstraint::new("orders_success.total", "payments_success.amount")
685 .group_by(vec!["customer_id"]);
686 let result = constraint.evaluate(&ctx).await?;
687
688 assert_eq!(result.status, ConstraintStatus::Success);
689 assert!(result.metric.is_some());
690
691 Ok(())
692 }
693
694 #[tokio::test]
695 async fn test_cross_table_sum_violation() -> Result<()> {
696 let ctx = create_test_context().await?;
697
698 ctx.sql("CREATE TABLE orders_violation (id BIGINT, customer_id BIGINT, total DOUBLE)")
700 .await?
701 .collect()
702 .await?;
703 ctx.sql("INSERT INTO orders_violation VALUES (1, 1, 100.0), (2, 1, 200.0)")
704 .await?
705 .collect()
706 .await?;
707 ctx.sql("CREATE TABLE payments_violation (id BIGINT, customer_id BIGINT, amount DOUBLE)")
708 .await?
709 .collect()
710 .await?;
711 ctx.sql("INSERT INTO payments_violation VALUES (1, 1, 250.0)")
712 .await?
713 .collect()
714 .await?;
715
716 let constraint =
717 CrossTableSumConstraint::new("orders_violation.total", "payments_violation.amount")
718 .group_by(vec!["customer_id"]);
719 let result = constraint.evaluate(&ctx).await?;
720
721 assert_eq!(result.status, ConstraintStatus::Failure);
722 assert!(result.message.is_some());
723 assert!(result.metric.is_some());
724
725 let message = result.message.unwrap();
726 assert!(message.contains("Cross-table sum mismatch"));
727 assert!(message.contains("customer_id"));
728
729 Ok(())
730 }
731
732 #[tokio::test]
733 async fn test_cross_table_sum_with_tolerance() -> Result<()> {
734 let ctx = create_test_context().await?;
735
736 ctx.sql("CREATE TABLE orders_tolerance (id BIGINT, total DOUBLE)")
738 .await?
739 .collect()
740 .await?;
741 ctx.sql("INSERT INTO orders_tolerance VALUES (1, 100.005)")
742 .await?
743 .collect()
744 .await?;
745 ctx.sql("CREATE TABLE payments_tolerance (id BIGINT, amount DOUBLE)")
746 .await?
747 .collect()
748 .await?;
749 ctx.sql("INSERT INTO payments_tolerance VALUES (1, 100.001)")
750 .await?
751 .collect()
752 .await?;
753
754 let constraint_no_tolerance =
756 CrossTableSumConstraint::new("orders_tolerance.total", "payments_tolerance.amount");
757 let result = constraint_no_tolerance.evaluate(&ctx).await?;
758 assert_eq!(result.status, ConstraintStatus::Failure);
759
760 let constraint_with_tolerance =
762 CrossTableSumConstraint::new("orders_tolerance.total", "payments_tolerance.amount")
763 .tolerance(0.01);
764 let result = constraint_with_tolerance.evaluate(&ctx).await?;
765 assert_eq!(result.status, ConstraintStatus::Success);
766
767 Ok(())
768 }
769
770 #[tokio::test]
771 async fn test_cross_table_sum_no_grouping() -> Result<()> {
772 let ctx = create_test_context().await?;
773 create_test_tables(&ctx, "no_grouping").await?;
774
775 let constraint =
776 CrossTableSumConstraint::new("orders_no_grouping.total", "payments_no_grouping.amount");
777 let result = constraint.evaluate(&ctx).await?;
778
779 assert_eq!(result.status, ConstraintStatus::Success);
780
781 Ok(())
782 }
783
/// `table.column` is split into its parts; inputs with one or three parts are
/// rejected.
#[test]
fn test_parse_qualified_column() {
    let constraint = CrossTableSumConstraint::new("orders.total", "payments.amount");

    let parsed = constraint.parse_qualified_column("orders.total").unwrap();
    assert_eq!(parsed.0, "orders");
    assert_eq!(parsed.1, "total");

    for malformed in ["invalid_column", "too.many.parts"].iter().copied() {
        assert!(constraint.parse_qualified_column(malformed).is_err());
    }
}
796
/// Builder methods must store exactly the values they are given.
#[test]
fn test_constraint_configuration() {
    let configured = CrossTableSumConstraint::new("orders.total", "payments.amount")
        .group_by(vec!["customer_id", "order_date"])
        .tolerance(0.01)
        .max_violations_reported(50);

    assert_eq!(configured.left_column(), "orders.total");
    assert_eq!(configured.right_column(), "payments.amount");
    assert_eq!(
        configured.group_by_columns(),
        &["customer_id", "order_date"]
    );
    assert_eq!(configured.tolerance, 0.01);
    assert_eq!(configured.max_violations_reported, 50);
}
813
/// The constraint must report its fixed name.
#[test]
fn test_constraint_name() {
    let name_under_test = CrossTableSumConstraint::new("orders.total", "payments.amount");
    assert_eq!(name_under_test.name(), "cross_table_sum");
}
819}