term_guard/constraints/
column_count.rs1use crate::constraints::Assertion;
4use crate::core::{Constraint, ConstraintMetadata, ConstraintResult, ConstraintStatus};
5use crate::prelude::*;
6use async_trait::async_trait;
7use datafusion::prelude::*;
8use tracing::instrument;
9
10#[derive(Debug, Clone)]
29pub struct ColumnCountConstraint {
30 assertion: Assertion,
31}
32
33impl ColumnCountConstraint {
34 pub fn new(assertion: Assertion) -> Self {
40 Self { assertion }
41 }
42}
43
44#[async_trait]
45impl Constraint for ColumnCountConstraint {
46 #[instrument(skip(self, ctx), fields(assertion = ?self.assertion))]
47 async fn evaluate(&self, ctx: &SessionContext) -> Result<ConstraintResult> {
48 let df = ctx.table("data").await.map_err(|e| {
50 TermError::constraint_evaluation(
51 self.name(),
52 format!("Failed to access table 'data': {e}"),
53 )
54 })?;
55
56 let column_count = df.schema().fields().len() as f64;
58
59 let assertion_result = self.assertion.evaluate(column_count);
61
62 let status = if assertion_result {
63 ConstraintStatus::Success
64 } else {
65 ConstraintStatus::Failure
66 };
67
68 let message = if status == ConstraintStatus::Failure {
69 Some(format!(
70 "Column count {column_count} does not satisfy assertion {}",
71 self.assertion.description()
72 ))
73 } else {
74 None
75 };
76
77 Ok(ConstraintResult {
78 status,
79 metric: Some(column_count),
80 message,
81 })
82 }
83
84 fn name(&self) -> &str {
85 "column_count"
86 }
87
88 fn column(&self) -> Option<&str> {
89 None }
91
92 fn metadata(&self) -> ConstraintMetadata {
93 ConstraintMetadata::default()
94 .with_description(format!(
95 "Checks that the dataset has {} columns",
96 self.assertion.description()
97 ))
98 .with_custom("assertion", self.assertion.description())
99 .with_custom("constraint_type", "schema")
100 }
101}
102
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::ConstraintStatus;
    use arrow::array::Int64Array;
    use arrow::datatypes::{DataType, Field, Schema};
    use arrow::record_batch::RecordBatch;
    use datafusion::datasource::MemTable;
    use std::sync::Arc;

    /// Builds a session whose `data` table has `num_columns` nullable Int64
    /// columns named `col_0`, `col_1`, ..., each holding the values 1, 2, 3.
    async fn create_test_context_with_columns(num_columns: usize) -> SessionContext {
        let mut fields: Vec<Field> = Vec::with_capacity(num_columns);
        let mut arrays: Vec<Arc<dyn arrow::array::Array>> = Vec::with_capacity(num_columns);

        for i in 0..num_columns {
            fields.push(Field::new(format!("col_{i}"), DataType::Int64, true));
            arrays.push(
                Arc::new(Int64Array::from(vec![Some(1), Some(2), Some(3)]))
                    as Arc<dyn arrow::array::Array>,
            );
        }

        let schema = Arc::new(Schema::new(fields));
        let batch = RecordBatch::try_new(schema.clone(), arrays).unwrap();
        let provider = MemTable::try_new(schema, vec![vec![batch]]).unwrap();

        let ctx = SessionContext::new();
        ctx.register_table("data", Arc::new(provider)).unwrap();
        ctx
    }

    /// Evaluates a column-count constraint built from `assertion` against
    /// `ctx`, panicking if evaluation itself returns an error.
    async fn evaluate_against(ctx: &SessionContext, assertion: Assertion) -> ConstraintResult {
        let constraint = ColumnCountConstraint::new(assertion);
        constraint.evaluate(ctx).await.unwrap()
    }

    #[tokio::test]
    async fn test_column_count_equals() {
        let ctx = create_test_context_with_columns(5).await;

        let matching = evaluate_against(&ctx, Assertion::Equals(5.0)).await;
        assert_eq!(matching.status, ConstraintStatus::Success);
        assert_eq!(matching.metric, Some(5.0));

        let mismatched = evaluate_against(&ctx, Assertion::Equals(10.0)).await;
        assert_eq!(mismatched.status, ConstraintStatus::Failure);
        assert_eq!(mismatched.metric, Some(5.0));
        assert!(mismatched.message.is_some());
    }

    #[tokio::test]
    async fn test_column_count_greater_than() {
        let ctx = create_test_context_with_columns(8).await;

        let above = evaluate_against(&ctx, Assertion::GreaterThan(5.0)).await;
        assert_eq!(above.status, ConstraintStatus::Success);

        let not_above = evaluate_against(&ctx, Assertion::GreaterThan(10.0)).await;
        assert_eq!(not_above.status, ConstraintStatus::Failure);
    }

    #[tokio::test]
    async fn test_column_count_less_than() {
        let ctx = create_test_context_with_columns(3).await;

        let below = evaluate_against(&ctx, Assertion::LessThan(5.0)).await;
        assert_eq!(below.status, ConstraintStatus::Success);

        let not_below = evaluate_against(&ctx, Assertion::LessThan(2.0)).await;
        assert_eq!(not_below.status, ConstraintStatus::Failure);
    }

    #[tokio::test]
    async fn test_column_count_between() {
        let ctx = create_test_context_with_columns(7).await;

        let inside = evaluate_against(&ctx, Assertion::Between(5.0, 10.0)).await;
        assert_eq!(inside.status, ConstraintStatus::Success);

        let outside = evaluate_against(&ctx, Assertion::Between(10.0, 15.0)).await;
        assert_eq!(outside.status, ConstraintStatus::Failure);
    }

    #[tokio::test]
    async fn test_single_column_dataset() {
        let ctx = create_test_context_with_columns(1).await;

        let result = evaluate_against(&ctx, Assertion::Equals(1.0)).await;
        assert_eq!(result.status, ConstraintStatus::Success);
        assert_eq!(result.metric, Some(1.0));
    }

    #[tokio::test]
    async fn test_large_column_count() {
        let ctx = create_test_context_with_columns(100).await;

        let result = evaluate_against(&ctx, Assertion::GreaterThanOrEqual(100.0)).await;
        assert_eq!(result.status, ConstraintStatus::Success);
        assert_eq!(result.metric, Some(100.0));
    }

    #[tokio::test]
    async fn test_missing_table() {
        // No table is registered, so the lookup of `data` has to fail.
        let ctx = SessionContext::new();
        let constraint = ColumnCountConstraint::new(Assertion::Equals(5.0));

        let outcome = constraint.evaluate(&ctx).await;
        assert!(outcome.is_err());
    }

    #[tokio::test]
    async fn test_metadata() {
        let constraint = ColumnCountConstraint::new(Assertion::Between(10.0, 20.0));
        let metadata = constraint.metadata();

        let description = metadata.description.unwrap_or_default();
        assert!(description.contains("between 10 and 20"));

        assert_eq!(
            metadata.custom.get("constraint_type"),
            Some(&"schema".to_string())
        );
    }
}