term_guard/sources/
joined.rs

1//! Joined data sources for cross-table validation in Term.
2//!
3//! This module provides infrastructure for validating data relationships across multiple tables,
4//! enabling foreign key validation, referential integrity checks, and cross-table consistency
5//! rules that are not possible with single-table validation approaches.
6//!
7//! # Examples
8//!
9//! ## Basic Foreign Key Validation
10//!
11//! ```rust,no_run
12//! use term_guard::sources::{JoinedSource, JoinType, DataSource, CsvSource};
13//! use term_guard::core::{Check, Level};
14//!
15//! // Create joined source for orders -> customers relationship
16//! let joined_source = JoinedSource::builder()
17//!     .left_source(CsvSource::new("orders.csv").unwrap(), "orders")
18//!     .right_source(CsvSource::new("customers.csv").unwrap(), "customers")
//!     .on("customer_id", "id")
//!     .join_type(JoinType::Inner)
21//!     .build()
22//!     .unwrap();
23//! ```
24
25use crate::error::{Result, TermError};
26use crate::sources::DataSource;
27use crate::telemetry::TermTelemetry;
28use arrow::datatypes::Schema;
29use async_trait::async_trait;
30use datafusion::prelude::*;
31use serde::{Deserialize, Serialize};
32use std::sync::Arc;
33use tracing::{debug, info, instrument};
34
35/// Types of joins supported for cross-table validation
36#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
37pub enum JoinType {
38    /// Inner join - only rows that match in both tables
39    Inner,
40    /// Left join - all rows from left table, matching rows from right
41    Left,
42    /// Right join - all rows from right table, matching rows from left  
43    Right,
44    /// Full outer join - all rows from both tables
45    Full,
46}
47
48impl JoinType {
49    /// Convert to SQL join syntax
50    pub fn to_sql(&self) -> &'static str {
51        match self {
52            JoinType::Inner => "INNER JOIN",
53            JoinType::Left => "LEFT JOIN",
54            JoinType::Right => "RIGHT JOIN",
55            JoinType::Full => "FULL OUTER JOIN",
56        }
57    }
58}
59
60/// Join condition specifying how tables are connected
61#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
62pub struct JoinCondition {
63    /// Column from the left table
64    pub left_column: String,
65    /// Column from the right table  
66    pub right_column: String,
67    /// Type of join to perform
68    pub join_type: JoinType,
69}
70
71impl JoinCondition {
72    /// Create a new join condition
73    pub fn new(
74        left_column: impl Into<String>,
75        right_column: impl Into<String>,
76        join_type: JoinType,
77    ) -> Self {
78        Self {
79            left_column: left_column.into(),
80            right_column: right_column.into(),
81            join_type,
82        }
83    }
84
85    /// Generate SQL join clause
86    pub fn to_sql(&self, left_alias: &str, right_alias: &str) -> String {
87        format!(
88            "{} ON {left_alias}.{} = {right_alias}.{}",
89            self.join_type.to_sql(),
90            self.left_column,
91            self.right_column
92        )
93    }
94}
95
96/// A data source that represents the join of multiple tables
97#[derive(Debug, Clone)]
98pub struct JoinedSource {
99    /// Primary (left) data source
100    left_source: Arc<dyn DataSource>,
101    /// Left table alias
102    left_alias: String,
103    /// Secondary (right) data source
104    right_source: Arc<dyn DataSource>,
105    /// Right table alias
106    right_alias: String,
107    /// Join condition
108    join_condition: JoinCondition,
109    /// Optional WHERE clause to filter joined results
110    where_clause: Option<String>,
111    /// Additional join stages for multi-table joins
112    additional_joins: Vec<AdditionalJoin>,
113}
114
115/// Additional join for multi-table scenarios
116#[derive(Debug, Clone)]
117struct AdditionalJoin {
118    source: Arc<dyn DataSource>,
119    alias: String,
120    condition: JoinCondition,
121}
122
123impl JoinedSource {
124    /// Create a new builder for configuring joined sources
125    pub fn builder() -> JoinedSourceBuilder {
126        JoinedSourceBuilder::new()
127    }
128
129    /// Get the left table alias
130    pub fn left_alias(&self) -> &str {
131        &self.left_alias
132    }
133
134    /// Get the right table alias  
135    pub fn right_alias(&self) -> &str {
136        &self.right_alias
137    }
138
139    /// Get the join condition
140    pub fn join_condition(&self) -> &JoinCondition {
141        &self.join_condition
142    }
143
144    /// Generate SQL query for the joined tables
145    #[instrument(skip(self))]
146    pub fn generate_sql(&self, table_name: &str) -> String {
147        let join_type_sql = self.join_condition.join_type.to_sql();
148
149        // When joining tables, we need to check if the column already has a table prefix
150        // If it does (e.g., "orders.customer_id"), use it as-is
151        // If it doesn't (e.g., "customer_id"), add the table alias
152        let left_col = if self.join_condition.left_column.contains('.') {
153            self.join_condition.left_column.clone()
154        } else {
155            format!("{}.{}", self.left_alias, self.join_condition.left_column)
156        };
157
158        let right_col = if self.join_condition.right_column.contains('.') {
159            self.join_condition.right_column.clone()
160        } else {
161            format!("{}.{}", self.right_alias, self.join_condition.right_column)
162        };
163
164        let on_clause = format!("ON {left_col} = {right_col}");
165
166        let mut sql = format!(
167            "CREATE OR REPLACE VIEW {table_name} AS SELECT * FROM {} {join_type_sql} {} {on_clause}",
168            self.left_alias,
169            self.right_alias
170        );
171
172        // Add additional joins if any
173        for additional in &self.additional_joins {
174            sql.push(' ');
175            sql.push_str(
176                &additional
177                    .condition
178                    .to_sql(&self.left_alias, &additional.alias),
179            );
180            sql.push(' ');
181            sql.push_str(&additional.alias);
182        }
183
184        // Add WHERE clause if specified
185        if let Some(where_clause) = &self.where_clause {
186            sql.push_str(" WHERE ");
187            sql.push_str(where_clause);
188        }
189
190        debug!("Generated SQL for joined source: {}", sql);
191        sql
192    }
193}
194
195#[async_trait]
196impl DataSource for JoinedSource {
197    #[instrument(skip(self, ctx))]
198    async fn register_with_telemetry(
199        &self,
200        ctx: &SessionContext,
201        table_name: &str,
202        telemetry: Option<&Arc<TermTelemetry>>,
203    ) -> Result<()> {
204        info!("Registering joined source as table: {}", table_name);
205
206        // Register individual sources with their aliases
207        self.left_source
208            .register_with_telemetry(ctx, &self.left_alias, telemetry)
209            .await
210            .map_err(|e| {
211                TermError::data_source(
212                    "joined",
213                    format!("Failed to register left source '{}': {e}", self.left_alias),
214                )
215            })?;
216
217        self.right_source
218            .register_with_telemetry(ctx, &self.right_alias, telemetry)
219            .await
220            .map_err(|e| {
221                TermError::data_source(
222                    "joined",
223                    format!(
224                        "Failed to register right source '{}': {e}",
225                        self.right_alias
226                    ),
227                )
228            })?;
229
230        // Register additional join sources
231        for additional in &self.additional_joins {
232            additional
233                .source
234                .register_with_telemetry(ctx, &additional.alias, telemetry)
235                .await
236                .map_err(|e| {
237                    TermError::data_source(
238                        "joined",
239                        format!(
240                            "Failed to register additional source '{}': {e}",
241                            additional.alias
242                        ),
243                    )
244                })?;
245        }
246
247        // Create the joined view
248        let sql = self.generate_sql(table_name);
249        ctx.sql(&sql).await.map_err(|e| {
250            TermError::data_source("joined", format!("Failed to create joined view: {e}"))
251        })?;
252
253        info!("Successfully registered joined source: {}", table_name);
254        Ok(())
255    }
256
257    fn schema(&self) -> Option<&Arc<Schema>> {
258        // For joined sources, we don't have a pre-computed schema
259        // The schema will be determined when the view is created
260        None
261    }
262
263    fn description(&self) -> String {
264        format!(
265            "Joined source: {} {} {} ON {}.{} = {}.{}",
266            self.left_alias,
267            self.join_condition.join_type.to_sql(),
268            self.right_alias,
269            self.left_alias,
270            self.join_condition.left_column,
271            self.right_alias,
272            self.join_condition.right_column
273        )
274    }
275}
276
277/// Builder for configuring joined data sources
278pub struct JoinedSourceBuilder {
279    left_source: Option<(Arc<dyn DataSource>, String)>,
280    right_source: Option<(Arc<dyn DataSource>, String)>,
281    join_condition: Option<JoinCondition>,
282    where_clause: Option<String>,
283    additional_joins: Vec<AdditionalJoin>,
284}
285
286impl JoinedSourceBuilder {
287    fn new() -> Self {
288        Self {
289            left_source: None,
290            right_source: None,
291            join_condition: None,
292            where_clause: None,
293            additional_joins: Vec::new(),
294        }
295    }
296
297    /// Set the left (primary) data source
298    pub fn left_source<S: DataSource + 'static>(
299        mut self,
300        source: S,
301        alias: impl Into<String>,
302    ) -> Self {
303        self.left_source = Some((Arc::new(source), alias.into()));
304        self
305    }
306
307    /// Set the right (secondary) data source
308    pub fn right_source<S: DataSource + 'static>(
309        mut self,
310        source: S,
311        alias: impl Into<String>,
312    ) -> Self {
313        self.right_source = Some((Arc::new(source), alias.into()));
314        self
315    }
316
317    /// Set the join condition with inner join
318    pub fn on(mut self, left_column: impl Into<String>, right_column: impl Into<String>) -> Self {
319        self.join_condition = Some(JoinCondition::new(
320            left_column,
321            right_column,
322            JoinType::Inner,
323        ));
324        self
325    }
326
327    /// Set the join condition with specified join type
328    pub fn join_on(
329        mut self,
330        left_column: impl Into<String>,
331        right_column: impl Into<String>,
332        join_type: JoinType,
333    ) -> Self {
334        self.join_condition = Some(JoinCondition::new(left_column, right_column, join_type));
335        self
336    }
337
338    /// Set the join type (defaults to inner join)
339    pub fn join_type(mut self, join_type: JoinType) -> Self {
340        if let Some(ref mut condition) = self.join_condition {
341            condition.join_type = join_type;
342        }
343        self
344    }
345
346    /// Add a WHERE clause to filter joined results
347    pub fn where_clause(mut self, clause: impl Into<String>) -> Self {
348        self.where_clause = Some(clause.into());
349        self
350    }
351
352    /// Add an additional join for multi-table scenarios
353    pub fn additional_join<S: DataSource + 'static>(
354        mut self,
355        source: S,
356        alias: impl Into<String>,
357        left_column: impl Into<String>,
358        right_column: impl Into<String>,
359        join_type: JoinType,
360    ) -> Self {
361        self.additional_joins.push(AdditionalJoin {
362            source: Arc::new(source),
363            alias: alias.into(),
364            condition: JoinCondition::new(left_column, right_column, join_type),
365        });
366        self
367    }
368
369    /// Build the joined source
370    pub fn build(self) -> Result<JoinedSource> {
371        let left_source = self
372            .left_source
373            .ok_or_else(|| TermError::data_source("joined", "Left source is required"))?;
374        let right_source = self
375            .right_source
376            .ok_or_else(|| TermError::data_source("joined", "Right source is required"))?;
377        let join_condition = self
378            .join_condition
379            .ok_or_else(|| TermError::data_source("joined", "Join condition is required"))?;
380
381        Ok(JoinedSource {
382            left_source: left_source.0,
383            left_alias: left_source.1,
384            right_source: right_source.0,
385            right_alias: right_source.1,
386            join_condition,
387            where_clause: self.where_clause,
388            additional_joins: self.additional_joins,
389        })
390    }
391}
392
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sources::CsvSource;
    use std::io::Write;
    use tempfile::NamedTempFile;

    /// Write `data` to a fresh temporary `.csv` file and return its handle.
    fn create_test_csv(data: &str) -> Result<NamedTempFile> {
        let mut temp_file = NamedTempFile::with_suffix(".csv")?;
        write!(temp_file, "{data}")?;
        temp_file.flush()?;
        Ok(temp_file)
    }

    /// Create a `CsvSource` backed by a temporary file containing `data`.
    ///
    /// The file handle is returned too: the caller must keep it alive for as long
    /// as the source is used, because dropping it deletes the file.
    fn csv_source(data: &str) -> Result<(NamedTempFile, CsvSource)> {
        let file = create_test_csv(data)?;
        let source = CsvSource::new(file.path().to_string_lossy().to_string())?;
        Ok((file, source))
    }

    #[test]
    fn test_join_type_sql() {
        assert_eq!(JoinType::Inner.to_sql(), "INNER JOIN");
        assert_eq!(JoinType::Left.to_sql(), "LEFT JOIN");
        assert_eq!(JoinType::Right.to_sql(), "RIGHT JOIN");
        assert_eq!(JoinType::Full.to_sql(), "FULL OUTER JOIN");
    }

    #[test]
    fn test_join_condition_sql() {
        let condition = JoinCondition::new("customer_id", "id", JoinType::Inner);
        assert_eq!(
            condition.to_sql("orders", "customers"),
            "INNER JOIN ON orders.customer_id = customers.id"
        );
    }

    #[tokio::test]
    async fn test_joined_source_builder() -> Result<()> {
        let (_orders_file, orders_source) =
            csv_source("order_id,customer_id,amount\n1,1,100.0\n2,2,200.0")?;
        let (_customers_file, customers_source) = csv_source("id,name\n1,Alice\n2,Bob")?;

        let joined_source = JoinedSource::builder()
            .left_source(orders_source, "orders")
            .right_source(customers_source, "customers")
            .on("customer_id", "id")
            .build()?;

        assert_eq!(joined_source.left_alias(), "orders");
        assert_eq!(joined_source.right_alias(), "customers");
        assert_eq!(joined_source.join_condition().join_type, JoinType::Inner);

        Ok(())
    }

    #[tokio::test]
    async fn test_joined_source_registration() -> Result<()> {
        // The third order references a customer that does not exist.
        let (_orders_file, orders_source) =
            csv_source("order_id,customer_id,amount\n1,1,100.0\n2,2,200.0\n3,999,300.0")?;
        let (_customers_file, customers_source) = csv_source("id,name\n1,Alice\n2,Bob")?;

        let joined_source = JoinedSource::builder()
            .left_source(orders_source, "orders")
            .right_source(customers_source, "customers")
            .join_on("customer_id", "id", JoinType::Left)
            .build()?;

        let ctx = SessionContext::new();
        joined_source
            .register(&ctx, "orders_with_customers")
            .await?;

        // Verify the joined table exists and can be queried
        let batches = ctx
            .sql("SELECT order_id FROM orders_with_customers")
            .await?
            .collect()
            .await?;

        // A left join keeps all 3 orders, including the one with a missing customer.
        // Count rows across batches rather than the number of batches returned.
        let row_count: usize = batches.iter().map(|batch| batch.num_rows()).sum();
        assert_eq!(row_count, 3);

        Ok(())
    }

    #[test]
    fn test_joined_source_sql_generation() -> Result<()> {
        let (_orders_file, orders_source) = csv_source("id,customer_id,amount\n1,1,100.0")?;
        let (_customers_file, customers_source) = csv_source("id,name\n1,Alice")?;

        let joined_source = JoinedSource::builder()
            .left_source(orders_source, "orders")
            .right_source(customers_source, "customers")
            .on("customer_id", "id")
            .where_clause("orders.amount > 50")
            .build()?;

        let sql = joined_source.generate_sql("test_view");

        assert!(sql.contains("CREATE OR REPLACE VIEW test_view"));
        assert!(sql.contains("INNER JOIN"));
        assert!(sql.contains("orders.customer_id = customers.id"));
        assert!(sql.contains("WHERE orders.amount > 50"));

        Ok(())
    }

    #[test]
    fn test_joined_source_description() -> Result<()> {
        let (_orders_file, orders_source) = csv_source("id,customer_id,amount\n1,1,100.0")?;
        let (_customers_file, customers_source) = csv_source("id,name\n1,Alice")?;

        let joined_source = JoinedSource::builder()
            .left_source(orders_source, "orders")
            .right_source(customers_source, "customers")
            .join_on("customer_id", "id", JoinType::Left)
            .build()?;

        let description = joined_source.description();
        assert!(description.contains("orders"));
        assert!(description.contains("LEFT JOIN"));
        assert!(description.contains("customers"));
        assert!(description.contains("customer_id"));

        Ok(())
    }
}
535}