1use crate::core::{Constraint, ConstraintResult, ConstraintStatus};
38use crate::error::{Result, TermError};
39use crate::security::SqlSecurity;
40use arrow::array::{Array, Float64Array};
41use async_trait::async_trait;
42use datafusion::prelude::*;
43use serde::{Deserialize, Serialize};
44use tracing::{debug, instrument, warn};
45
/// Constraint that checks how completely two tables match each other on a set of join keys.
///
/// An instance is created with `JoinCoverageConstraint::new` and configured through the
/// chained setter methods; evaluation compares the measured match rate against
/// `expected_match_rate`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct JoinCoverageConstraint {
    /// Name of the left-hand table of the join.
    left_table: String,
    /// Name of the right-hand table of the join.
    right_table: String,
    /// Join key pairs as `(left_column, right_column)`; all pairs are combined with `AND`.
    join_keys: Vec<(String, String)>,
    /// Minimum match rate required for the constraint to succeed.
    /// The `expect_match_rate` setter clamps this to `0.0..=1.0`.
    expected_match_rate: f64,
    /// Which direction(s) of the join the match rate is measured for.
    coverage_type: CoverageType,
    /// Whether coverage is computed over distinct key values rather than over every row.
    distinct_only: bool,
    /// Row limit applied to the query that looks up unmatched keys after a failure;
    /// `0` disables that query entirely.
    max_examples_reported: usize,
}
77
/// Direction in which join coverage is measured.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum CoverageType {
    /// Measure how much of the left table finds a match in the right table.
    LeftCoverage,
    /// Measure how much of the right table finds a match in the left table.
    RightCoverage,
    /// Measure both directions; the smaller of the two rates is the reported match rate.
    BidirectionalCoverage,
}
88
89impl JoinCoverageConstraint {
90 pub fn new(left_table: impl Into<String>, right_table: impl Into<String>) -> Self {
105 Self {
106 left_table: left_table.into(),
107 right_table: right_table.into(),
108 join_keys: Vec::new(),
109 expected_match_rate: 1.0,
110 coverage_type: CoverageType::LeftCoverage,
111 distinct_only: false,
112 max_examples_reported: 100,
113 }
114 }
115
116 pub fn on(mut self, left_column: impl Into<String>, right_column: impl Into<String>) -> Self {
123 self.join_keys = vec![(left_column.into(), right_column.into())];
124 self
125 }
126
127 pub fn on_multiple(mut self, keys: Vec<(impl Into<String>, impl Into<String>)>) -> Self {
133 self.join_keys = keys
134 .into_iter()
135 .map(|(l, r)| (l.into(), r.into()))
136 .collect();
137 self
138 }
139
140 pub fn expect_match_rate(mut self, rate: f64) -> Self {
146 self.expected_match_rate = rate.clamp(0.0, 1.0);
147 self
148 }
149
150 pub fn coverage_type(mut self, coverage_type: CoverageType) -> Self {
152 self.coverage_type = coverage_type;
153 self
154 }
155
156 pub fn distinct_only(mut self, distinct: bool) -> Self {
158 self.distinct_only = distinct;
159 self
160 }
161
162 pub fn max_examples_reported(mut self, max_examples: usize) -> Self {
164 self.max_examples_reported = max_examples;
165 self
166 }
167
168 fn validate_identifiers(&self) -> Result<()> {
170 SqlSecurity::validate_identifier(&self.left_table)?;
171 SqlSecurity::validate_identifier(&self.right_table)?;
172
173 for (left_col, right_col) in &self.join_keys {
174 SqlSecurity::validate_identifier(left_col)?;
175 SqlSecurity::validate_identifier(right_col)?;
176 }
177
178 Ok(())
179 }
180
181 fn generate_coverage_query(&self) -> Result<String> {
183 self.validate_identifiers()?;
184
185 if self.join_keys.is_empty() {
186 return Err(TermError::constraint_evaluation(
187 "join_coverage",
188 "No join keys specified. Use .on() or .on_multiple() to set join keys",
189 ));
190 }
191
192 let join_condition = self
193 .join_keys
194 .iter()
195 .map(|(l, r)| format!("{}.{l} = {}.{r}", self.left_table, self.right_table))
196 .collect::<Vec<_>>()
197 .join(" AND ");
198
199 let count_expr = if self.distinct_only {
200 format!(
201 "COUNT(DISTINCT {}.{})",
202 self.left_table, self.join_keys[0].0
203 )
204 } else {
205 "COUNT(*)".to_string()
206 };
207
208 let sql = match self.coverage_type {
209 CoverageType::LeftCoverage => {
210 format!(
212 "WITH coverage_stats AS (
213 SELECT
214 {count_expr} as total_left,
215 SUM(CASE WHEN {}.{} IS NOT NULL THEN 1 ELSE 0 END) as matched_left
216 FROM {}
217 LEFT JOIN {} ON {join_condition}
218 )
219 SELECT
220 total_left,
221 matched_left,
222 CAST(matched_left AS DOUBLE) / CAST(total_left AS DOUBLE) as match_rate
223 FROM coverage_stats",
224 self.right_table, self.join_keys[0].1, self.left_table, self.right_table
225 )
226 }
227 CoverageType::RightCoverage => {
228 format!(
230 "WITH coverage_stats AS (
231 SELECT
232 {count_expr} as total_right,
233 SUM(CASE WHEN {}.{} IS NOT NULL THEN 1 ELSE 0 END) as matched_right
234 FROM {}
235 RIGHT JOIN {} ON {join_condition}
236 )
237 SELECT
238 total_right,
239 matched_right,
240 CAST(matched_right AS DOUBLE) / CAST(total_right AS DOUBLE) as match_rate
241 FROM coverage_stats",
242 self.left_table, self.join_keys[0].0, self.right_table, self.left_table
243 )
244 }
245 CoverageType::BidirectionalCoverage => {
246 format!(
248 "WITH left_coverage AS (
249 SELECT
250 COUNT(*) as total_left,
251 SUM(CASE WHEN {}.{} IS NOT NULL THEN 1 ELSE 0 END) as matched_left
252 FROM {}
253 LEFT JOIN {} ON {join_condition}
254 ),
255 right_coverage AS (
256 SELECT
257 COUNT(*) as total_right,
258 SUM(CASE WHEN {}.{} IS NOT NULL THEN 1 ELSE 0 END) as matched_right
259 FROM {}
260 RIGHT JOIN {} ON {join_condition}
261 )
262 SELECT
263 l.total_left,
264 l.matched_left,
265 r.total_right,
266 r.matched_right,
267 LEAST(
268 CAST(l.matched_left AS DOUBLE) / CAST(l.total_left AS DOUBLE),
269 CAST(r.matched_right AS DOUBLE) / CAST(r.total_right AS DOUBLE)
270 ) as match_rate
271 FROM left_coverage l, right_coverage r",
272 self.right_table,
273 self.join_keys[0].1,
274 self.left_table,
275 self.right_table,
276 self.left_table,
277 self.join_keys[0].0,
278 self.right_table,
279 self.left_table
280 )
281 }
282 };
283
284 debug!("Generated join coverage query: {}", sql);
285 Ok(sql)
286 }
287
288 fn generate_unmatched_query(&self) -> Result<String> {
290 if self.max_examples_reported == 0 {
291 return Ok(String::new());
292 }
293
294 self.validate_identifiers()?;
295
296 let join_condition = self
297 .join_keys
298 .iter()
299 .map(|(l, r)| format!("{}.{l} = {}.{r}", self.left_table, self.right_table))
300 .collect::<Vec<_>>()
301 .join(" AND ");
302
303 let key_columns = self
304 .join_keys
305 .iter()
306 .map(|(l, _)| format!("{}.{l}", self.left_table))
307 .collect::<Vec<_>>()
308 .join(", ");
309
310 let sql = format!(
311 "SELECT DISTINCT {key_columns}
312 FROM {}
313 LEFT JOIN {} ON {join_condition}
314 WHERE {}.{} IS NULL
315 LIMIT {}",
316 self.left_table,
317 self.right_table,
318 self.right_table,
319 self.join_keys[0].1,
320 self.max_examples_reported
321 );
322
323 Ok(sql)
324 }
325}
326
327#[async_trait]
328impl Constraint for JoinCoverageConstraint {
329 #[instrument(skip(self, ctx), fields(constraint = "join_coverage"))]
330 async fn evaluate(&self, ctx: &SessionContext) -> Result<ConstraintResult> {
331 debug!(
332 "Evaluating join coverage: {} <-> {} on {:?}",
333 self.left_table, self.right_table, self.join_keys
334 );
335
336 let sql = self.generate_coverage_query()?;
338 let df = ctx.sql(&sql).await.map_err(|e| {
339 TermError::constraint_evaluation(
340 "join_coverage",
341 format!("Join coverage query failed: {e}"),
342 )
343 })?;
344
345 let batches = df.collect().await.map_err(|e| {
346 TermError::constraint_evaluation(
347 "join_coverage",
348 format!("Failed to collect join coverage results: {e}"),
349 )
350 })?;
351
352 if batches.is_empty() || batches[0].num_rows() == 0 {
353 return Err(TermError::constraint_evaluation(
354 "join_coverage",
355 "No results from join coverage query",
356 ));
357 }
358
359 let batch = &batches[0];
361 let match_rate_col = batch
362 .column(batch.num_columns() - 1) .as_any()
364 .downcast_ref::<Float64Array>()
365 .ok_or_else(|| {
366 TermError::constraint_evaluation("join_coverage", "Invalid match rate column type")
367 })?;
368
369 let match_rate = match_rate_col.value(0);
370
371 debug!(
372 "Join coverage: {:.2}% (expected: {:.2}%)",
373 match_rate * 100.0,
374 self.expected_match_rate * 100.0
375 );
376
377 if match_rate >= self.expected_match_rate {
379 return Ok(ConstraintResult::success_with_metric(match_rate));
380 }
381
382 let coverage_desc = match self.coverage_type {
384 CoverageType::LeftCoverage => format!("{} -> {}", self.left_table, self.right_table),
385 CoverageType::RightCoverage => format!("{} <- {}", self.left_table, self.right_table),
386 CoverageType::BidirectionalCoverage => {
387 format!("{} <-> {}", self.left_table, self.right_table)
388 }
389 };
390
391 let unmatched_query = self.generate_unmatched_query()?;
393 let examples_msg = if !unmatched_query.is_empty() {
394 match ctx.sql(&unmatched_query).await {
395 Ok(df) => match df.collect().await {
396 Ok(batches) if !batches.is_empty() && batches[0].num_rows() > 0 => {
397 let examples_count = batches[0].num_rows();
398 format!(" ({examples_count} unmatched examples found)")
399 }
400 _ => String::new(),
401 },
402 _ => String::new(),
403 }
404 } else {
405 String::new()
406 };
407
408 let message = format!(
409 "Join coverage constraint failed: {coverage_desc} coverage is {:.2}% (expected: {:.2}%){examples_msg}",
410 match_rate * 100.0,
411 self.expected_match_rate * 100.0
412 );
413
414 warn!("{}", message);
415
416 Ok(ConstraintResult {
417 status: ConstraintStatus::Failure,
418 metric: Some(match_rate),
419 message: Some(message),
420 })
421 }
422
423 fn name(&self) -> &str {
424 "join_coverage"
425 }
426}
427
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_utils::create_test_context;

    /// Executes every statement to completion, in the order given.
    async fn run_statements(ctx: &SessionContext, statements: &[&str]) -> Result<()> {
        for statement in statements {
            ctx.sql(statement).await?.collect().await?;
        }
        Ok(())
    }

    /// Every order references an existing customer, so coverage is exactly 100%.
    #[tokio::test]
    async fn test_join_coverage_success() -> Result<()> {
        let ctx = create_test_context().await?;
        run_statements(
            &ctx,
            &[
                "CREATE TABLE orders_cov (id BIGINT, customer_id BIGINT)",
                "INSERT INTO orders_cov VALUES (1, 1), (2, 2), (3, 3)",
                "CREATE TABLE customers_cov (id BIGINT, name STRING)",
                "INSERT INTO customers_cov VALUES (1, 'Alice'), (2, 'Bob'), (3, 'Charlie')",
            ],
        )
        .await?;

        let outcome = JoinCoverageConstraint::new("orders_cov", "customers_cov")
            .on("customer_id", "id")
            .expect_match_rate(1.0)
            .evaluate(&ctx)
            .await?;

        assert_eq!(outcome.status, ConstraintStatus::Success);
        assert_eq!(outcome.metric, Some(1.0));

        Ok(())
    }

    /// Two of three orders match, which clears a 60% threshold.
    #[tokio::test]
    async fn test_join_coverage_partial() -> Result<()> {
        let ctx = create_test_context().await?;
        run_statements(
            &ctx,
            &[
                "CREATE TABLE orders_partial (id BIGINT, customer_id BIGINT)",
                "INSERT INTO orders_partial VALUES (1, 1), (2, 2), (3, 999)",
                "CREATE TABLE customers_partial (id BIGINT, name STRING)",
                "INSERT INTO customers_partial VALUES (1, 'Alice'), (2, 'Bob')",
            ],
        )
        .await?;

        let outcome = JoinCoverageConstraint::new("orders_partial", "customers_partial")
            .on("customer_id", "id")
            .expect_match_rate(0.6)
            .evaluate(&ctx)
            .await?;

        assert_eq!(outcome.status, ConstraintStatus::Success);
        assert!((outcome.metric.unwrap() - 0.666).abs() < 0.01);

        Ok(())
    }

    /// No order matches any customer, so a 90% threshold fails with a 0% metric.
    #[tokio::test]
    async fn test_join_coverage_failure() -> Result<()> {
        let ctx = create_test_context().await?;
        run_statements(
            &ctx,
            &[
                "CREATE TABLE orders_low (id BIGINT, customer_id BIGINT)",
                "INSERT INTO orders_low VALUES (1, 999), (2, 998), (3, 997)",
                "CREATE TABLE customers_low (id BIGINT, name STRING)",
                "INSERT INTO customers_low VALUES (1, 'Alice')",
            ],
        )
        .await?;

        let outcome = JoinCoverageConstraint::new("orders_low", "customers_low")
            .on("customer_id", "id")
            .expect_match_rate(0.9)
            .evaluate(&ctx)
            .await?;

        assert_eq!(outcome.status, ConstraintStatus::Failure);
        assert_eq!(outcome.metric, Some(0.0));
        assert!(outcome.message.is_some());

        Ok(())
    }

    /// Each builder method stores its argument in the matching field.
    #[test]
    fn test_constraint_configuration() {
        let configured = JoinCoverageConstraint::new("orders", "customers")
            .on("customer_id", "id")
            .expect_match_rate(0.95)
            .coverage_type(CoverageType::BidirectionalCoverage)
            .distinct_only(true)
            .max_examples_reported(50);

        assert_eq!(configured.left_table, "orders");
        assert_eq!(configured.right_table, "customers");
        assert_eq!(configured.expected_match_rate, 0.95);
        assert_eq!(
            configured.coverage_type,
            CoverageType::BidirectionalCoverage
        );
        assert!(configured.distinct_only);
        assert_eq!(configured.max_examples_reported, 50);
    }

    /// `on_multiple` keeps every key pair in the order supplied.
    #[test]
    fn test_composite_keys() {
        let composite = JoinCoverageConstraint::new("orders", "products")
            .on_multiple(vec![("product_id", "id"), ("variant", "variant_code")])
            .expect_match_rate(0.98);

        let expected_pairs = vec![
            ("product_id".to_string(), "id".to_string()),
            ("variant".to_string(), "variant_code".to_string()),
        ];

        assert_eq!(composite.join_keys.len(), 2);
        assert_eq!(composite.join_keys[0], expected_pairs[0]);
        assert_eq!(composite.join_keys[1], expected_pairs[1]);
    }
}