term_guard/constraints/
column_count.rs

//! Column count validation constraint.
2
3use crate::constraints::Assertion;
4use crate::core::{Constraint, ConstraintMetadata, ConstraintResult, ConstraintStatus};
5use crate::prelude::*;
6use async_trait::async_trait;
7use datafusion::prelude::*;
8use tracing::instrument;
9/// A constraint that validates the number of columns in a dataset.
10///
11/// This constraint checks if the column count of the dataset meets the specified assertion.
12/// It uses DataFusion's schema to determine the number of columns.
13///
14/// # Examples
15///
16/// ```rust
17/// use term_guard::constraints::{Assertion, ColumnCountConstraint};
18/// use term_guard::core::Constraint;
19///
20/// // Check for exactly 15 columns
21/// let constraint = ColumnCountConstraint::new(Assertion::Equals(15.0));
22/// assert_eq!(constraint.name(), "column_count");
23///
24/// // Check for at least 10 columns
25/// let constraint = ColumnCountConstraint::new(Assertion::GreaterThanOrEqual(10.0));
26/// ```
27#[derive(Debug, Clone)]
28pub struct ColumnCountConstraint {
29    assertion: Assertion,
30}
31
32impl ColumnCountConstraint {
33    /// Creates a new column count constraint.
34    ///
35    /// # Arguments
36    ///
37    /// * `assertion` - The assertion to apply to the column count
38    pub fn new(assertion: Assertion) -> Self {
39        Self { assertion }
40    }
41}
42
43#[async_trait]
44impl Constraint for ColumnCountConstraint {
45    #[instrument(skip(self, ctx), fields(assertion = ?self.assertion))]
46    async fn evaluate(&self, ctx: &SessionContext) -> Result<ConstraintResult> {
47        // Get the table from the context
48        let df = ctx.table("data").await.map_err(|e| {
49            TermError::constraint_evaluation(
50                self.name(),
51                format!("Failed to access table 'data': {e}"),
52            )
53        })?;
54
55        // Get the column count from the schema
56        let column_count = df.schema().fields().len() as f64;
57
58        // Evaluate the assertion
59        let assertion_result = self.assertion.evaluate(column_count);
60
61        let status = if assertion_result {
62            ConstraintStatus::Success
63        } else {
64            ConstraintStatus::Failure
65        };
66
67        let message = if status == ConstraintStatus::Failure {
68            Some(format!(
69                "Column count {column_count} does not satisfy assertion {}",
70                self.assertion.description()
71            ))
72        } else {
73            None
74        };
75
76        Ok(ConstraintResult {
77            status,
78            metric: Some(column_count),
79            message,
80        })
81    }
82
83    fn name(&self) -> &str {
84        "column_count"
85    }
86
87    fn column(&self) -> Option<&str> {
88        None // This constraint operates on the entire dataset, not a specific column
89    }
90
91    fn metadata(&self) -> ConstraintMetadata {
92        ConstraintMetadata::default()
93            .with_description(format!(
94                "Checks that the dataset has {} columns",
95                self.assertion.description()
96            ))
97            .with_custom("assertion", self.assertion.description())
98            .with_custom("constraint_type", "schema")
99    }
100}
101
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::ConstraintStatus;
    use arrow::array::Int64Array;
    use arrow::datatypes::{DataType, Field, Schema};
    use arrow::record_batch::RecordBatch;
    use datafusion::datasource::MemTable;
    use std::sync::Arc;

    use crate::test_helpers::evaluate_constraint_with_context;

    /// Builds a session whose `data` table has `num_columns` nullable `Int64`
    /// columns named `col_0`, `col_1`, ..., each holding the values 1, 2, 3.
    async fn create_test_context_with_columns(num_columns: usize) -> SessionContext {
        // Build the schema fields and the matching column arrays side by side.
        let mut fields: Vec<Field> = Vec::with_capacity(num_columns);
        let mut arrays: Vec<Arc<dyn arrow::array::Array>> = Vec::with_capacity(num_columns);
        for i in 0..num_columns {
            fields.push(Field::new(format!("col_{i}"), DataType::Int64, true));
            arrays.push(Arc::new(Int64Array::from(vec![Some(1), Some(2), Some(3)])));
        }

        let schema = Arc::new(Schema::new(fields));
        let batch = RecordBatch::try_new(schema.clone(), arrays).unwrap();
        let provider = MemTable::try_new(schema, vec![vec![batch]]).unwrap();

        let ctx = SessionContext::new();
        ctx.register_table("data", Arc::new(provider)).unwrap();
        ctx
    }

    #[tokio::test]
    async fn test_column_count_equals() {
        let ctx = create_test_context_with_columns(5).await;

        // Exact match: succeeds and reports the observed count.
        let matching = ColumnCountConstraint::new(Assertion::Equals(5.0));
        let outcome = evaluate_constraint_with_context(&matching, &ctx, "data")
            .await
            .unwrap();
        assert_eq!(outcome.status, ConstraintStatus::Success);
        assert_eq!(outcome.metric, Some(5.0));

        // Mismatch: fails, still reports the observed count, and explains why.
        let mismatched = ColumnCountConstraint::new(Assertion::Equals(10.0));
        let outcome = evaluate_constraint_with_context(&mismatched, &ctx, "data")
            .await
            .unwrap();
        assert_eq!(outcome.status, ConstraintStatus::Failure);
        assert_eq!(outcome.metric, Some(5.0));
        assert!(outcome.message.is_some());
    }

    #[tokio::test]
    async fn test_column_count_greater_than() {
        let ctx = create_test_context_with_columns(8).await;

        // Success case first, then failure case.
        let cases = [
            (Assertion::GreaterThan(5.0), ConstraintStatus::Success),
            (Assertion::GreaterThan(10.0), ConstraintStatus::Failure),
        ];
        for (assertion, expected_status) in cases {
            let constraint = ColumnCountConstraint::new(assertion);
            let result = evaluate_constraint_with_context(&constraint, &ctx, "data")
                .await
                .unwrap();
            assert_eq!(result.status, expected_status);
        }
    }

    #[tokio::test]
    async fn test_column_count_less_than() {
        let ctx = create_test_context_with_columns(3).await;

        // Success case first, then failure case.
        let cases = [
            (Assertion::LessThan(5.0), ConstraintStatus::Success),
            (Assertion::LessThan(2.0), ConstraintStatus::Failure),
        ];
        for (assertion, expected_status) in cases {
            let constraint = ColumnCountConstraint::new(assertion);
            let result = evaluate_constraint_with_context(&constraint, &ctx, "data")
                .await
                .unwrap();
            assert_eq!(result.status, expected_status);
        }
    }

    #[tokio::test]
    async fn test_column_count_between() {
        let ctx = create_test_context_with_columns(7).await;

        // Within the range first, then outside the range.
        let cases = [
            (Assertion::Between(5.0, 10.0), ConstraintStatus::Success),
            (Assertion::Between(10.0, 15.0), ConstraintStatus::Failure),
        ];
        for (assertion, expected_status) in cases {
            let constraint = ColumnCountConstraint::new(assertion);
            let result = evaluate_constraint_with_context(&constraint, &ctx, "data")
                .await
                .unwrap();
            assert_eq!(result.status, expected_status);
        }
    }

    #[tokio::test]
    async fn test_single_column_dataset() {
        let ctx = create_test_context_with_columns(1).await;
        let constraint = ColumnCountConstraint::new(Assertion::Equals(1.0));

        let outcome = evaluate_constraint_with_context(&constraint, &ctx, "data")
            .await
            .unwrap();

        assert_eq!(outcome.status, ConstraintStatus::Success);
        assert_eq!(outcome.metric, Some(1.0));
    }

    #[tokio::test]
    async fn test_large_column_count() {
        let ctx = create_test_context_with_columns(100).await;
        let constraint = ColumnCountConstraint::new(Assertion::GreaterThanOrEqual(100.0));

        let outcome = evaluate_constraint_with_context(&constraint, &ctx, "data")
            .await
            .unwrap();

        assert_eq!(outcome.status, ConstraintStatus::Success);
        assert_eq!(outcome.metric, Some(100.0));
    }

    #[tokio::test]
    async fn test_missing_table() {
        // A fresh context with no table registered under `data`.
        let ctx = SessionContext::new();
        let constraint = ColumnCountConstraint::new(Assertion::Equals(5.0));

        let outcome = evaluate_constraint_with_context(&constraint, &ctx, "data").await;

        assert!(outcome.is_err());
    }

    #[tokio::test]
    async fn test_metadata() {
        let metadata = ColumnCountConstraint::new(Assertion::Between(10.0, 20.0)).metadata();

        let description = metadata.description.unwrap_or_default();
        assert!(description.contains("between 10 and 20"));

        let constraint_type = metadata.custom.get("constraint_type");
        assert_eq!(constraint_type, Some(&"schema".to_string()));
    }
}