1use crate::analyzers::suggestions::{ConstraintParameter, SuggestionPriority};
39use crate::core::Check;
40use datafusion::arrow::datatypes::{DataType, Schema};
41use datafusion::prelude::*;
42use serde::{Deserialize, Serialize};
43use std::collections::HashMap;
44use std::sync::Arc;
45use tracing::{info, instrument};
46
47pub struct SchemaAnalyzer<'a> {
52 ctx: &'a SessionContext,
53 naming_patterns: NamingPatterns,
54}
55
/// Lowercase column-name fragments used to classify columns.
#[derive(Debug, Clone)]
struct NamingPatterns {
    /// Suffixes that mark a column as a likely foreign key (e.g. `customer_id`).
    foreign_key_suffixes: Vec<String>,
    /// Substrings that mark a column as holding a date/time value.
    temporal_patterns: Vec<String>,
    /// Substrings that mark a numeric column as a monetary amount.
    amount_patterns: Vec<String>,
    /// Substrings that mark a numeric column as a count or quantity.
    quantity_patterns: Vec<String>,
}

impl Default for NamingPatterns {
    fn default() -> Self {
        // Converts a list of string literals into owned pattern strings.
        fn owned(words: &[&str]) -> Vec<String> {
            words.iter().map(|word| (*word).to_owned()).collect()
        }

        Self {
            foreign_key_suffixes: owned(&["_id", "_key", "_fk", "_ref"]),
            temporal_patterns: owned(&[
                "_at",
                "_date",
                "_time",
                "_timestamp",
                "created",
                "updated",
                "modified",
                "processed",
                "completed",
            ]),
            amount_patterns: owned(&[
                "amount", "total", "price", "cost", "payment", "revenue", "balance",
            ]),
            quantity_patterns: owned(&["quantity", "qty", "count", "units", "items"]),
        }
    }
}
108
109#[derive(Debug, Clone, Serialize, Deserialize)]
111pub struct CrossTableSuggestion {
112 pub constraint_type: String,
114 pub tables: Vec<String>,
116 pub columns: HashMap<String, Vec<String>>,
118 pub confidence: f64,
120 pub rationale: String,
122 pub priority: SuggestionPriority,
124 pub parameters: HashMap<String, ConstraintParameter>,
126}
127
128impl<'a> SchemaAnalyzer<'a> {
129 pub fn new(ctx: &'a SessionContext) -> Self {
131 Self {
132 ctx,
133 naming_patterns: NamingPatterns::default(),
134 }
135 }
136
137 #[instrument(skip(self))]
139 pub async fn analyze_all_tables(&self) -> crate::error::Result<Vec<CrossTableSuggestion>> {
140 let mut suggestions = Vec::new();
141
142 let catalog = self.ctx.catalog("datafusion").unwrap();
144 let schema = catalog.schema("public").unwrap();
145 let table_names: Vec<String> = schema.table_names();
146
147 info!(
148 "Analyzing {} tables for constraint suggestions",
149 table_names.len()
150 );
151
152 let mut table_schemas = HashMap::new();
154 for table_name in &table_names {
155 if let Ok(Some(table)) = schema.table(table_name).await {
156 let schema = table.schema();
157 table_schemas.insert(table_name.clone(), schema);
158 }
159 }
160
161 suggestions.extend(self.analyze_foreign_keys(&table_schemas));
163
164 suggestions.extend(self.analyze_temporal_constraints(&table_schemas));
166
167 suggestions.extend(self.analyze_financial_consistency(&table_schemas));
169
170 suggestions.extend(self.analyze_join_coverage(&table_schemas));
172
173 suggestions.sort_by(|a, b| match (&a.priority, &b.priority) {
175 (SuggestionPriority::Critical, SuggestionPriority::Critical) => {
176 b.confidence.partial_cmp(&a.confidence).unwrap()
177 }
178 (SuggestionPriority::Critical, _) => std::cmp::Ordering::Less,
179 (_, SuggestionPriority::Critical) => std::cmp::Ordering::Greater,
180 _ => b.confidence.partial_cmp(&a.confidence).unwrap(),
181 });
182
183 Ok(suggestions)
184 }
185
186 fn analyze_foreign_keys(
188 &self,
189 schemas: &HashMap<String, Arc<Schema>>,
190 ) -> Vec<CrossTableSuggestion> {
191 let mut suggestions = Vec::new();
192
193 for (table_name, schema) in schemas {
194 for field in schema.fields() {
195 if let Some(referenced_table) = self.detect_foreign_key(field.name(), schemas) {
197 if let Some(ref_schema) = schemas.get(&referenced_table) {
199 let ref_column =
200 self.infer_primary_key_column(&referenced_table, ref_schema);
201
202 let mut columns = HashMap::new();
203 columns.insert(table_name.clone(), vec![field.name().to_string()]);
204 columns.insert(referenced_table.clone(), vec![ref_column.clone()]);
205
206 suggestions.push(CrossTableSuggestion {
207 constraint_type: "foreign_key".to_string(),
208 tables: vec![table_name.clone(), referenced_table.clone()],
209 columns,
210 confidence: self.calculate_fk_confidence(field.name(), &referenced_table),
211 rationale: format!(
212 "Column '{}' in '{table_name}' appears to reference '{referenced_table}' based on naming convention",
213 field.name()
214 ),
215 priority: SuggestionPriority::High,
216 parameters: HashMap::new(),
217 });
218 }
219 }
220 }
221 }
222
223 suggestions
224 }
225
226 fn detect_foreign_key(
228 &self,
229 column_name: &str,
230 schemas: &HashMap<String, Arc<Schema>>,
231 ) -> Option<String> {
232 for suffix in &self.naming_patterns.foreign_key_suffixes {
234 if column_name.ends_with(suffix) {
235 let base_name = &column_name[..column_name.len() - suffix.len()];
237
238 for table_name in schemas.keys() {
240 if self.matches_table_name(base_name, table_name) {
241 return Some(table_name.clone());
242 }
243 }
244 }
245 }
246
247 None
248 }
249
250 fn matches_table_name(&self, base_name: &str, table_name: &str) -> bool {
252 if base_name == table_name {
254 return true;
255 }
256
257 if format!("{base_name}s") == table_name {
259 return true;
260 }
261
262 if base_name == format!("{table_name}s") {
264 return true;
265 }
266
267 if base_name.ends_with('y')
269 && table_name == format!("{}ies", &base_name[..base_name.len() - 1])
270 {
271 return true;
272 }
273
274 false
275 }
276
277 fn infer_primary_key_column(&self, table_name: &str, schema: &Arc<Schema>) -> String {
279 let table_id = format!("{table_name}_id");
281 let table_key = format!("{table_name}_key");
282 let common_pk_names = vec!["id", table_id.as_str(), "key", table_key.as_str()];
283
284 for field in schema.fields() {
285 for pk_name in &common_pk_names {
286 if field.name().to_lowercase() == pk_name.to_lowercase() {
287 return field.name().to_string();
288 }
289 }
290 }
291
292 "id".to_string()
294 }
295
296 fn calculate_fk_confidence(&self, column_name: &str, referenced_table: &str) -> f64 {
298 let mut confidence: f64 = 0.5; if column_name.contains(referenced_table)
302 || column_name.contains(&referenced_table[..referenced_table.len().saturating_sub(1)])
303 {
304 confidence += 0.3;
305 }
306
307 if column_name.ends_with("_id") {
309 confidence += 0.2;
310 }
311
312 confidence.min(1.0)
313 }
314
315 fn analyze_temporal_constraints(
317 &self,
318 schemas: &HashMap<String, Arc<Schema>>,
319 ) -> Vec<CrossTableSuggestion> {
320 let mut suggestions = Vec::new();
321
322 for (table_name, schema) in schemas {
323 let temporal_columns = self.find_temporal_columns(schema);
324
325 if temporal_columns.len() >= 2 {
327 for i in 0..temporal_columns.len() {
328 for j in i + 1..temporal_columns.len() {
329 let col1 = &temporal_columns[i];
330 let col2 = &temporal_columns[j];
331
332 let (before, after) = self.infer_temporal_order(col1, col2);
334
335 let mut columns = HashMap::new();
336 columns.insert(table_name.clone(), vec![before.clone(), after.clone()]);
337
338 let mut parameters = HashMap::new();
339 parameters.insert(
340 "validation_type".to_string(),
341 ConstraintParameter::String("before_after".to_string()),
342 );
343
344 suggestions.push(CrossTableSuggestion {
345 constraint_type: "temporal_ordering".to_string(),
346 tables: vec![table_name.clone()],
347 columns,
348 confidence: 0.8,
349 rationale: format!(
350 "Columns '{before}' and '{after}' appear to have a temporal relationship"
351 ),
352 priority: SuggestionPriority::Medium,
353 parameters,
354 });
355 }
356 }
357 }
358
359 for col in &temporal_columns {
361 if col.contains("transaction") || col.contains("order") || col.contains("payment") {
362 let mut columns = HashMap::new();
363 columns.insert(table_name.clone(), vec![col.clone()]);
364
365 let mut parameters = HashMap::new();
366 parameters.insert(
367 "start_time".to_string(),
368 ConstraintParameter::String("09:00".to_string()),
369 );
370 parameters.insert(
371 "end_time".to_string(),
372 ConstraintParameter::String("17:00".to_string()),
373 );
374
375 suggestions.push(CrossTableSuggestion {
376 constraint_type: "business_hours".to_string(),
377 tables: vec![table_name.clone()],
378 columns,
379 confidence: 0.6,
380 rationale: format!(
381 "Column '{col}' may benefit from business hours validation"
382 ),
383 priority: SuggestionPriority::Low,
384 parameters,
385 });
386 }
387 }
388 }
389
390 suggestions
391 }
392
393 fn find_temporal_columns(&self, schema: &Arc<Schema>) -> Vec<String> {
395 let mut temporal_columns = Vec::new();
396
397 for field in schema.fields() {
398 let is_temporal_type = matches!(
400 field.data_type(),
401 DataType::Date32
402 | DataType::Date64
403 | DataType::Timestamp(_, _)
404 | DataType::Time32(_)
405 | DataType::Time64(_)
406 );
407
408 let matches_pattern = self
410 .naming_patterns
411 .temporal_patterns
412 .iter()
413 .any(|pattern| field.name().to_lowercase().contains(pattern));
414
415 if is_temporal_type || matches_pattern {
416 temporal_columns.push(field.name().to_string());
417 }
418 }
419
420 temporal_columns
421 }
422
423 fn infer_temporal_order(&self, col1: &str, col2: &str) -> (String, String) {
425 let order_keywords = vec![
426 ("created", 0),
427 ("started", 1),
428 ("updated", 2),
429 ("modified", 2),
430 ("processed", 3),
431 ("completed", 4),
432 ("finished", 4),
433 ("ended", 5),
434 ];
435
436 let get_order = |col: &str| -> i32 {
437 for (keyword, order) in &order_keywords {
438 if col.to_lowercase().contains(keyword) {
439 return *order;
440 }
441 }
442 100 };
444
445 let order1 = get_order(col1);
446 let order2 = get_order(col2);
447
448 if order1 <= order2 {
449 (col1.to_string(), col2.to_string())
450 } else {
451 (col2.to_string(), col1.to_string())
452 }
453 }
454
455 fn analyze_financial_consistency(
457 &self,
458 schemas: &HashMap<String, Arc<Schema>>,
459 ) -> Vec<CrossTableSuggestion> {
460 let mut suggestions = Vec::new();
461
462 let mut amount_columns: HashMap<String, Vec<String>> = HashMap::new();
464 let mut quantity_columns: HashMap<String, Vec<String>> = HashMap::new();
465
466 for (table_name, schema) in schemas {
467 for field in schema.fields() {
468 if self.is_amount_column(field.name(), field.data_type()) {
469 amount_columns
470 .entry(table_name.clone())
471 .or_default()
472 .push(field.name().to_string());
473 }
474 if self.is_quantity_column(field.name(), field.data_type()) {
475 quantity_columns
476 .entry(table_name.clone())
477 .or_default()
478 .push(field.name().to_string());
479 }
480 }
481 }
482
483 for (table1, cols1) in &amount_columns {
485 for (table2, cols2) in &amount_columns {
486 if table1 < table2 && self.are_tables_related(table1, table2, schemas) {
487 for col1 in cols1 {
488 for col2 in cols2 {
489 if self.are_columns_likely_related(col1, col2) {
490 let mut columns = HashMap::new();
491 columns.insert(table1.clone(), vec![col1.clone()]);
492 columns.insert(table2.clone(), vec![col2.clone()]);
493
494 let mut parameters = HashMap::new();
495 parameters.insert(
496 "tolerance".to_string(),
497 ConstraintParameter::Float(0.01),
498 );
499
500 suggestions.push(CrossTableSuggestion {
501 constraint_type: "cross_table_sum".to_string(),
502 tables: vec![table1.clone(), table2.clone()],
503 columns,
504 confidence: 0.7,
505 rationale: format!(
506 "Financial columns '{table1}.{col1}' and '{table2}.{col2}' may need sum consistency validation"
507 ),
508 priority: SuggestionPriority::High,
509 parameters,
510 });
511 }
512 }
513 }
514 }
515 }
516 }
517
518 suggestions
519 }
520
521 fn is_amount_column(&self, name: &str, data_type: &DataType) -> bool {
523 let is_numeric = matches!(
525 data_type,
526 DataType::Float32
527 | DataType::Float64
528 | DataType::Decimal128(_, _)
529 | DataType::Decimal256(_, _)
530 );
531
532 if !is_numeric {
533 return false;
534 }
535
536 self.naming_patterns
538 .amount_patterns
539 .iter()
540 .any(|pattern| name.to_lowercase().contains(pattern))
541 }
542
543 fn is_quantity_column(&self, name: &str, data_type: &DataType) -> bool {
545 let is_numeric = matches!(
547 data_type,
548 DataType::Int8
549 | DataType::Int16
550 | DataType::Int32
551 | DataType::Int64
552 | DataType::UInt8
553 | DataType::UInt16
554 | DataType::UInt32
555 | DataType::UInt64
556 | DataType::Float32
557 | DataType::Float64
558 );
559
560 if !is_numeric {
561 return false;
562 }
563
564 self.naming_patterns
566 .quantity_patterns
567 .iter()
568 .any(|pattern| name.to_lowercase().contains(pattern))
569 }
570
571 fn are_tables_related(
573 &self,
574 table1: &str,
575 table2: &str,
576 schemas: &HashMap<String, Arc<Schema>>,
577 ) -> bool {
578 if let Some(schema1) = schemas.get(table1) {
580 for field in schema1.fields() {
581 if let Some(ref_table) = self.detect_foreign_key(field.name(), schemas) {
582 if ref_table == table2 {
583 return true;
584 }
585 }
586 }
587 }
588
589 if let Some(schema2) = schemas.get(table2) {
590 for field in schema2.fields() {
591 if let Some(ref_table) = self.detect_foreign_key(field.name(), schemas) {
592 if ref_table == table1 {
593 return true;
594 }
595 }
596 }
597 }
598
599 table1.contains(table2) || table2.contains(table1)
601 }
602
603 fn are_columns_likely_related(&self, col1: &str, col2: &str) -> bool {
605 if col1 == col2 {
607 return true;
608 }
609
610 let keywords = vec!["total", "amount", "sum", "payment", "cost", "price"];
612 for keyword in keywords {
613 if col1.contains(keyword) && col2.contains(keyword) {
614 return true;
615 }
616 }
617
618 false
619 }
620
621 fn analyze_join_coverage(
623 &self,
624 schemas: &HashMap<String, Arc<Schema>>,
625 ) -> Vec<CrossTableSuggestion> {
626 let mut suggestions = Vec::new();
627
628 for (table_name, schema) in schemas {
630 for field in schema.fields() {
631 if let Some(referenced_table) = self.detect_foreign_key(field.name(), schemas) {
632 let mut columns = HashMap::new();
633 columns.insert(table_name.clone(), vec![field.name().to_string()]);
634 columns.insert(referenced_table.clone(), vec!["id".to_string()]); let mut parameters = HashMap::new();
637 parameters.insert(
638 "expected_coverage".to_string(),
639 ConstraintParameter::Float(0.95),
640 );
641
642 suggestions.push(CrossTableSuggestion {
643 constraint_type: "join_coverage".to_string(),
644 tables: vec![table_name.clone(), referenced_table.clone()],
645 columns,
646 confidence: 0.75,
647 rationale: format!(
648 "Join between '{table_name}' and '{referenced_table}' should have high coverage for data quality"
649 ),
650 priority: SuggestionPriority::Medium,
651 parameters,
652 });
653 }
654 }
655 }
656
657 suggestions
658 }
659
660 pub fn suggestions_to_check(
662 &self,
663 suggestions: &[CrossTableSuggestion],
664 check_name: &str,
665 ) -> Check {
666 let mut builder = Check::builder(check_name);
667
668 for suggestion in suggestions {
669 match suggestion.constraint_type.as_str() {
670 "foreign_key" => {
671 if suggestion.tables.len() == 2 {
672 let child_col = format!(
673 "{}.{}",
674 suggestion.tables[0], suggestion.columns[&suggestion.tables[0]][0]
675 );
676 let parent_col = format!(
677 "{}.{}",
678 suggestion.tables[1], suggestion.columns[&suggestion.tables[1]][0]
679 );
680 builder = builder.foreign_key(child_col, parent_col);
681 }
682 }
683 "cross_table_sum" => {
684 if suggestion.tables.len() == 2 {
685 let left_col = format!(
686 "{}.{}",
687 suggestion.tables[0], suggestion.columns[&suggestion.tables[0]][0]
688 );
689 let right_col = format!(
690 "{}.{}",
691 suggestion.tables[1], suggestion.columns[&suggestion.tables[1]][0]
692 );
693 builder = builder.cross_table_sum(left_col, right_col);
694 }
695 }
696 "join_coverage" => {
697 if suggestion.tables.len() == 2 {
698 builder =
699 builder.join_coverage(&suggestion.tables[0], &suggestion.tables[1]);
700 }
701 }
702 "temporal_ordering" => {
703 if !suggestion.tables.is_empty() {
704 builder = builder.temporal_ordering(&suggestion.tables[0]);
705 }
706 }
707 _ => {}
708 }
709 }
710
711 builder.build()
712 }
713}
714
#[cfg(test)]
mod tests {
    use super::*;
    use datafusion::arrow::datatypes::{Field, Schema as ArrowSchema, TimeUnit};

    /// Wraps `fields` in a shared Arrow schema.
    fn schema_of(fields: Vec<Field>) -> Arc<Schema> {
        Arc::new(ArrowSchema::new(fields))
    }

    /// Microsecond-precision timestamp type without a time zone.
    fn timestamp() -> DataType {
        DataType::Timestamp(TimeUnit::Microsecond, None)
    }

    #[test]
    fn test_foreign_key_detection() {
        let ctx = SessionContext::new();
        let analyzer = SchemaAnalyzer::new(&ctx);
        let mut schemas = HashMap::new();

        schemas.insert(
            "orders".to_string(),
            schema_of(vec![
                Field::new("id", DataType::Int64, false),
                Field::new("customer_id", DataType::Int64, false),
                Field::new("total", DataType::Float64, false),
            ]),
        );
        schemas.insert(
            "customers".to_string(),
            schema_of(vec![
                Field::new("id", DataType::Int64, false),
                Field::new("name", DataType::Utf8, false),
            ]),
        );

        let suggestions = analyzer.analyze_foreign_keys(&schemas);

        assert!(!suggestions.is_empty());
        assert_eq!(suggestions[0].constraint_type, "foreign_key");
        assert!(suggestions[0].tables.contains(&"orders".to_string()));
        assert!(suggestions[0].tables.contains(&"customers".to_string()));
    }

    #[test]
    fn test_temporal_column_detection() {
        let ctx = SessionContext::new();
        let analyzer = SchemaAnalyzer::new(&ctx);

        let schema = schema_of(vec![
            Field::new("id", DataType::Int64, false),
            Field::new("created_at", timestamp(), false),
            Field::new("updated_at", timestamp(), false),
            Field::new("name", DataType::Utf8, false),
        ]);

        let temporal_cols = analyzer.find_temporal_columns(&schema);

        assert_eq!(temporal_cols.len(), 2);
        assert!(temporal_cols.contains(&"created_at".to_string()));
        assert!(temporal_cols.contains(&"updated_at".to_string()));
    }

    #[test]
    fn test_temporal_ordering() {
        let ctx = SessionContext::new();
        let analyzer = SchemaAnalyzer::new(&ctx);

        let (before, after) = analyzer.infer_temporal_order("created_at", "updated_at");
        assert_eq!(before, "created_at");
        assert_eq!(after, "updated_at");

        let (before, after) = analyzer.infer_temporal_order("processed_at", "created_at");
        assert_eq!(before, "created_at");
        assert_eq!(after, "processed_at");
    }

    #[test]
    fn test_amount_column_detection() {
        let ctx = SessionContext::new();
        let analyzer = SchemaAnalyzer::new(&ctx);

        assert!(analyzer.is_amount_column("total_amount", &DataType::Float64));
        assert!(analyzer.is_amount_column("price", &DataType::Decimal128(10, 2)));
        assert!(!analyzer.is_amount_column("customer_id", &DataType::Int64));
        assert!(!analyzer.is_amount_column("total", &DataType::Utf8));
    }

    #[test]
    fn test_table_name_plural_matching() {
        let ctx = SessionContext::new();
        let analyzer = SchemaAnalyzer::new(&ctx);

        assert!(analyzer.matches_table_name("order", "order"));
        assert!(analyzer.matches_table_name("customer", "customers"));
        assert!(analyzer.matches_table_name("customers", "customer"));
        assert!(analyzer.matches_table_name("category", "categories"));
        assert!(!analyzer.matches_table_name("order", "customers"));
    }

    #[test]
    fn test_primary_key_inference() {
        let ctx = SessionContext::new();
        let analyzer = SchemaAnalyzer::new(&ctx);

        let with_pk = schema_of(vec![
            Field::new("name", DataType::Utf8, false),
            Field::new("ID", DataType::Int64, false),
        ]);
        assert_eq!(analyzer.infer_primary_key_column("customers", &with_pk), "ID");

        let without_pk = schema_of(vec![Field::new("name", DataType::Utf8, false)]);
        assert_eq!(analyzer.infer_primary_key_column("customers", &without_pk), "id");
    }

    #[test]
    fn test_temporal_constraint_suggestions() {
        let ctx = SessionContext::new();
        let analyzer = SchemaAnalyzer::new(&ctx);
        let mut schemas = HashMap::new();
        schemas.insert(
            "events".to_string(),
            schema_of(vec![
                Field::new("updated_at", timestamp(), false),
                Field::new("created_at", timestamp(), false),
            ]),
        );

        let suggestions = analyzer.analyze_temporal_constraints(&schemas);

        assert_eq!(suggestions.len(), 1);
        assert_eq!(suggestions[0].constraint_type, "temporal_ordering");
        assert_eq!(
            suggestions[0].columns["events"],
            vec!["created_at".to_string(), "updated_at".to_string()]
        );
    }
}