term_guard/constraints/
column_count.rs1use crate::constraints::Assertion;
4use crate::core::{Constraint, ConstraintMetadata, ConstraintResult, ConstraintStatus};
5use crate::prelude::*;
6use async_trait::async_trait;
7use datafusion::prelude::*;
8use tracing::instrument;
9#[derive(Debug, Clone)]
28pub struct ColumnCountConstraint {
29 assertion: Assertion,
30}
31
32impl ColumnCountConstraint {
33 pub fn new(assertion: Assertion) -> Self {
39 Self { assertion }
40 }
41}
42
43#[async_trait]
44impl Constraint for ColumnCountConstraint {
45 #[instrument(skip(self, ctx), fields(assertion = ?self.assertion))]
46 async fn evaluate(&self, ctx: &SessionContext) -> Result<ConstraintResult> {
47 let df = ctx.table("data").await.map_err(|e| {
49 TermError::constraint_evaluation(
50 self.name(),
51 format!("Failed to access table 'data': {e}"),
52 )
53 })?;
54
55 let column_count = df.schema().fields().len() as f64;
57
58 let assertion_result = self.assertion.evaluate(column_count);
60
61 let status = if assertion_result {
62 ConstraintStatus::Success
63 } else {
64 ConstraintStatus::Failure
65 };
66
67 let message = if status == ConstraintStatus::Failure {
68 Some(format!(
69 "Column count {column_count} does not satisfy assertion {}",
70 self.assertion.description()
71 ))
72 } else {
73 None
74 };
75
76 Ok(ConstraintResult {
77 status,
78 metric: Some(column_count),
79 message,
80 })
81 }
82
83 fn name(&self) -> &str {
84 "column_count"
85 }
86
87 fn column(&self) -> Option<&str> {
88 None }
90
91 fn metadata(&self) -> ConstraintMetadata {
92 ConstraintMetadata::default()
93 .with_description(format!(
94 "Checks that the dataset has {} columns",
95 self.assertion.description()
96 ))
97 .with_custom("assertion", self.assertion.description())
98 .with_custom("constraint_type", "schema")
99 }
100}
101
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::ConstraintStatus;
    use crate::test_helpers::evaluate_constraint_with_context;
    use arrow::array::{ArrayRef, Int64Array};
    use arrow::datatypes::{DataType, Field, Schema};
    use arrow::record_batch::RecordBatch;
    use datafusion::datasource::MemTable;
    use std::sync::Arc;

    /// Builds a context whose `data` table has `num_columns` nullable `Int64`
    /// columns named `col_0`, `col_1`, ..., each holding the values 1, 2, 3.
    async fn create_test_context_with_columns(num_columns: usize) -> SessionContext {
        let ctx = SessionContext::new();

        let fields: Vec<Field> = (0..num_columns)
            .map(|i| Field::new(format!("col_{i}"), DataType::Int64, true))
            .collect();
        let schema = Arc::new(Schema::new(fields));

        let arrays: Vec<ArrayRef> = (0..num_columns)
            .map(|_| Arc::new(Int64Array::from(vec![Some(1), Some(2), Some(3)])) as ArrayRef)
            .collect();

        let batch = RecordBatch::try_new(Arc::clone(&schema), arrays).unwrap();
        let provider = MemTable::try_new(schema, vec![vec![batch]]).unwrap();
        ctx.register_table("data", Arc::new(provider)).unwrap();

        ctx
    }

    #[tokio::test]
    async fn test_column_count_equals() {
        let ctx = create_test_context_with_columns(5).await;

        let constraint = ColumnCountConstraint::new(Assertion::Equals(5.0));
        let result = evaluate_constraint_with_context(&constraint, &ctx, "data")
            .await
            .unwrap();
        assert_eq!(result.status, ConstraintStatus::Success);
        assert_eq!(result.metric, Some(5.0));
        assert!(result.message.is_none());

        let constraint = ColumnCountConstraint::new(Assertion::Equals(10.0));
        let result = evaluate_constraint_with_context(&constraint, &ctx, "data")
            .await
            .unwrap();
        assert_eq!(result.status, ConstraintStatus::Failure);
        assert_eq!(result.metric, Some(5.0));
        // The f64 count is rendered without a fractional part (`5`, not `5.0`).
        assert!(result
            .message
            .as_deref()
            .unwrap_or_default()
            .contains("Column count 5 does not satisfy assertion"));
    }

    #[tokio::test]
    async fn test_column_count_greater_than() {
        let ctx = create_test_context_with_columns(8).await;

        let constraint = ColumnCountConstraint::new(Assertion::GreaterThan(5.0));
        let result = evaluate_constraint_with_context(&constraint, &ctx, "data")
            .await
            .unwrap();
        assert_eq!(result.status, ConstraintStatus::Success);

        let constraint = ColumnCountConstraint::new(Assertion::GreaterThan(10.0));
        let result = evaluate_constraint_with_context(&constraint, &ctx, "data")
            .await
            .unwrap();
        assert_eq!(result.status, ConstraintStatus::Failure);
    }

    #[tokio::test]
    async fn test_column_count_less_than() {
        let ctx = create_test_context_with_columns(3).await;

        let constraint = ColumnCountConstraint::new(Assertion::LessThan(5.0));
        let result = evaluate_constraint_with_context(&constraint, &ctx, "data")
            .await
            .unwrap();
        assert_eq!(result.status, ConstraintStatus::Success);

        let constraint = ColumnCountConstraint::new(Assertion::LessThan(2.0));
        let result = evaluate_constraint_with_context(&constraint, &ctx, "data")
            .await
            .unwrap();
        assert_eq!(result.status, ConstraintStatus::Failure);
    }

    #[tokio::test]
    async fn test_column_count_between() {
        let ctx = create_test_context_with_columns(7).await;

        let constraint = ColumnCountConstraint::new(Assertion::Between(5.0, 10.0));
        let result = evaluate_constraint_with_context(&constraint, &ctx, "data")
            .await
            .unwrap();
        assert_eq!(result.status, ConstraintStatus::Success);

        let constraint = ColumnCountConstraint::new(Assertion::Between(10.0, 15.0));
        let result = evaluate_constraint_with_context(&constraint, &ctx, "data")
            .await
            .unwrap();
        assert_eq!(result.status, ConstraintStatus::Failure);
    }

    #[tokio::test]
    async fn test_single_column_dataset() {
        let ctx = create_test_context_with_columns(1).await;

        let constraint = ColumnCountConstraint::new(Assertion::Equals(1.0));
        let result = evaluate_constraint_with_context(&constraint, &ctx, "data")
            .await
            .unwrap();
        assert_eq!(result.status, ConstraintStatus::Success);
        assert_eq!(result.metric, Some(1.0));
    }

    #[tokio::test]
    async fn test_large_column_count() {
        let ctx = create_test_context_with_columns(100).await;

        let constraint = ColumnCountConstraint::new(Assertion::GreaterThanOrEqual(100.0));
        let result = evaluate_constraint_with_context(&constraint, &ctx, "data")
            .await
            .unwrap();
        assert_eq!(result.status, ConstraintStatus::Success);
        assert_eq!(result.metric, Some(100.0));
    }

    #[tokio::test]
    async fn test_missing_table() {
        // No `data` table is registered, so evaluation must report an error.
        let ctx = SessionContext::new();
        let constraint = ColumnCountConstraint::new(Assertion::Equals(5.0));

        let result = evaluate_constraint_with_context(&constraint, &ctx, "data").await;
        assert!(result.is_err());
    }

    #[test]
    fn test_name_and_column() {
        let constraint = ColumnCountConstraint::new(Assertion::Equals(1.0));
        assert_eq!(constraint.name(), "column_count");
        assert_eq!(constraint.column(), None);
    }

    #[tokio::test]
    async fn test_metadata() {
        let constraint = ColumnCountConstraint::new(Assertion::Between(10.0, 20.0));
        let metadata = constraint.metadata();

        assert!(metadata
            .description
            .unwrap_or_default()
            .contains("between 10 and 20"));
        assert_eq!(
            metadata.custom.get("constraint_type"),
            Some(&"schema".to_string())
        );
    }
}