term_guard/constraints/
column_count.rs

1//! Column count validation constraint.
2
3use crate::constraints::Assertion;
4use crate::core::{Constraint, ConstraintMetadata, ConstraintResult, ConstraintStatus};
5use crate::prelude::*;
6use async_trait::async_trait;
7use datafusion::prelude::*;
8use tracing::instrument;
9
10/// A constraint that validates the number of columns in a dataset.
11///
12/// This constraint checks if the column count of the dataset meets the specified assertion.
13/// It uses DataFusion's schema to determine the number of columns.
14///
15/// # Examples
16///
17/// ```rust
18/// use term_guard::constraints::{Assertion, ColumnCountConstraint};
19/// use term_guard::core::Constraint;
20///
21/// // Check for exactly 15 columns
22/// let constraint = ColumnCountConstraint::new(Assertion::Equals(15.0));
23/// assert_eq!(constraint.name(), "column_count");
24///
25/// // Check for at least 10 columns
26/// let constraint = ColumnCountConstraint::new(Assertion::GreaterThanOrEqual(10.0));
27/// ```
28#[derive(Debug, Clone)]
29pub struct ColumnCountConstraint {
30    assertion: Assertion,
31}
32
33impl ColumnCountConstraint {
34    /// Creates a new column count constraint.
35    ///
36    /// # Arguments
37    ///
38    /// * `assertion` - The assertion to apply to the column count
39    pub fn new(assertion: Assertion) -> Self {
40        Self { assertion }
41    }
42}
43
44#[async_trait]
45impl Constraint for ColumnCountConstraint {
46    #[instrument(skip(self, ctx), fields(assertion = ?self.assertion))]
47    async fn evaluate(&self, ctx: &SessionContext) -> Result<ConstraintResult> {
48        // Get the table from the context
49        let df = ctx.table("data").await.map_err(|e| {
50            TermError::constraint_evaluation(
51                self.name(),
52                format!("Failed to access table 'data': {e}"),
53            )
54        })?;
55
56        // Get the column count from the schema
57        let column_count = df.schema().fields().len() as f64;
58
59        // Evaluate the assertion
60        let assertion_result = self.assertion.evaluate(column_count);
61
62        let status = if assertion_result {
63            ConstraintStatus::Success
64        } else {
65            ConstraintStatus::Failure
66        };
67
68        let message = if status == ConstraintStatus::Failure {
69            Some(format!(
70                "Column count {column_count} does not satisfy assertion {}",
71                self.assertion.description()
72            ))
73        } else {
74            None
75        };
76
77        Ok(ConstraintResult {
78            status,
79            metric: Some(column_count),
80            message,
81        })
82    }
83
84    fn name(&self) -> &str {
85        "column_count"
86    }
87
88    fn column(&self) -> Option<&str> {
89        None // This constraint operates on the entire dataset, not a specific column
90    }
91
92    fn metadata(&self) -> ConstraintMetadata {
93        ConstraintMetadata::default()
94            .with_description(format!(
95                "Checks that the dataset has {} columns",
96                self.assertion.description()
97            ))
98            .with_custom("assertion", self.assertion.description())
99            .with_custom("constraint_type", "schema")
100    }
101}
102
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::ConstraintStatus;
    use arrow::array::Int64Array;
    use arrow::datatypes::{DataType, Field, Schema};
    use arrow::record_batch::RecordBatch;
    use datafusion::datasource::MemTable;
    use std::sync::Arc;

    /// Builds a session whose `data` table has `num_columns` nullable Int64
    /// columns named `col_0`, `col_1`, ..., each holding the values 1, 2, 3.
    async fn create_test_context_with_columns(num_columns: usize) -> SessionContext {
        let mut fields: Vec<Field> = Vec::with_capacity(num_columns);
        let mut columns: Vec<Arc<dyn arrow::array::Array>> = Vec::with_capacity(num_columns);

        // Field definitions and their backing arrays are produced side by side.
        for i in 0..num_columns {
            fields.push(Field::new(format!("col_{i}"), DataType::Int64, true));
            columns.push(
                Arc::new(Int64Array::from(vec![Some(1), Some(2), Some(3)]))
                    as Arc<dyn arrow::array::Array>,
            );
        }

        let schema = Arc::new(Schema::new(fields));
        let batch = RecordBatch::try_new(schema.clone(), columns).unwrap();
        let provider = MemTable::try_new(schema, vec![vec![batch]]).unwrap();

        let ctx = SessionContext::new();
        ctx.register_table("data", Arc::new(provider)).unwrap();
        ctx
    }

    /// Evaluates a fresh constraint built from `assertion` against `ctx`,
    /// panicking if evaluation itself reports an error.
    async fn evaluate_with(ctx: &SessionContext, assertion: Assertion) -> ConstraintResult {
        ColumnCountConstraint::new(assertion)
            .evaluate(ctx)
            .await
            .unwrap()
    }

    #[tokio::test]
    async fn test_column_count_equals() {
        let ctx = create_test_context_with_columns(5).await;

        // Matching count succeeds and reports the count as the metric
        let matched = evaluate_with(&ctx, Assertion::Equals(5.0)).await;
        assert_eq!(matched.status, ConstraintStatus::Success);
        assert_eq!(matched.metric, Some(5.0));

        // Differing count fails, still reports the metric, and explains why
        let mismatched = evaluate_with(&ctx, Assertion::Equals(10.0)).await;
        assert_eq!(mismatched.status, ConstraintStatus::Failure);
        assert_eq!(mismatched.metric, Some(5.0));
        assert!(mismatched.message.is_some());
    }

    #[tokio::test]
    async fn test_column_count_greater_than() {
        let ctx = create_test_context_with_columns(8).await;

        // 8 > 5 holds
        let above = evaluate_with(&ctx, Assertion::GreaterThan(5.0)).await;
        assert_eq!(above.status, ConstraintStatus::Success);

        // 8 > 10 does not hold
        let below = evaluate_with(&ctx, Assertion::GreaterThan(10.0)).await;
        assert_eq!(below.status, ConstraintStatus::Failure);
    }

    #[tokio::test]
    async fn test_column_count_less_than() {
        let ctx = create_test_context_with_columns(3).await;

        // 3 < 5 holds
        let under = evaluate_with(&ctx, Assertion::LessThan(5.0)).await;
        assert_eq!(under.status, ConstraintStatus::Success);

        // 3 < 2 does not hold
        let over = evaluate_with(&ctx, Assertion::LessThan(2.0)).await;
        assert_eq!(over.status, ConstraintStatus::Failure);
    }

    #[tokio::test]
    async fn test_column_count_between() {
        let ctx = create_test_context_with_columns(7).await;

        // 7 lies inside [5, 10]
        let inside = evaluate_with(&ctx, Assertion::Between(5.0, 10.0)).await;
        assert_eq!(inside.status, ConstraintStatus::Success);

        // 7 lies outside [10, 15]
        let outside = evaluate_with(&ctx, Assertion::Between(10.0, 15.0)).await;
        assert_eq!(outside.status, ConstraintStatus::Failure);
    }

    #[tokio::test]
    async fn test_single_column_dataset() {
        let ctx = create_test_context_with_columns(1).await;

        let outcome = evaluate_with(&ctx, Assertion::Equals(1.0)).await;
        assert_eq!(outcome.status, ConstraintStatus::Success);
        assert_eq!(outcome.metric, Some(1.0));
    }

    #[tokio::test]
    async fn test_large_column_count() {
        let ctx = create_test_context_with_columns(100).await;

        let outcome = evaluate_with(&ctx, Assertion::GreaterThanOrEqual(100.0)).await;
        assert_eq!(outcome.status, ConstraintStatus::Success);
        assert_eq!(outcome.metric, Some(100.0));
    }

    #[tokio::test]
    async fn test_missing_table() {
        // A brand-new context has no `data` table registered.
        let ctx = SessionContext::new();

        let constraint = ColumnCountConstraint::new(Assertion::Equals(5.0));
        let outcome = constraint.evaluate(&ctx).await;
        assert!(outcome.is_err());
    }

    #[tokio::test]
    async fn test_metadata() {
        let metadata = ColumnCountConstraint::new(Assertion::Between(10.0, 20.0)).metadata();

        let description = metadata.description.unwrap_or_default();
        assert!(description.contains("between 10 and 20"));

        let constraint_type = metadata.custom.get("constraint_type");
        assert_eq!(constraint_type, Some(&"schema".to_string()));
    }
}