1use vortex_array::arrays::ConstantArray;
2use vortex_array::compute::{
3 CompareKernel, CompareKernelAdapter, Operator, and, cast, compare, or,
4};
5use vortex_array::{Array, ArrayRef, IntoArray, register_kernel};
6use vortex_dtype::datetime::TemporalMetadata;
7use vortex_dtype::{DType, Nullability};
8use vortex_error::{VortexExpect as _, VortexResult};
9use vortex_scalar::Scalar;
10
11use crate::array::{DateTimePartsArray, DateTimePartsVTable};
12use crate::timestamp;
13
14impl CompareKernel for DateTimePartsVTable {
15 fn compare(
18 &self,
19 lhs: &DateTimePartsArray,
20 rhs: &dyn Array,
21 operator: Operator,
22 ) -> VortexResult<Option<ArrayRef>> {
23 let Some(rhs_const) = rhs.as_constant() else {
24 return Ok(None);
25 };
26 let Ok(timestamp) = rhs_const
27 .as_extension()
28 .storage()
29 .as_primitive()
30 .as_::<i64>()
31 .map(|maybe_value| maybe_value.vortex_expect("null scalar handled in top-level"))
32 else {
33 return Ok(None);
34 };
35
36 let DType::Extension(ext_dtype) = rhs_const.dtype() else {
37 return Ok(None);
38 };
39
40 let nullability = lhs.dtype().nullability() | rhs.dtype().nullability();
41
42 let temporal_metadata = TemporalMetadata::try_from(ext_dtype.as_ref())?;
43 let ts_parts = timestamp::split(timestamp, temporal_metadata.time_unit())?;
44
45 match operator {
46 Operator::Eq => compare_eq(lhs, &ts_parts, nullability),
47 Operator::NotEq => compare_ne(lhs, &ts_parts, nullability),
48 Operator::Lt => compare_lt(lhs, &ts_parts, nullability),
54 Operator::Lte => compare_lt(lhs, &ts_parts, nullability),
55 Operator::Gt => compare_gt(lhs, &ts_parts, nullability),
57 Operator::Gte => compare_gt(lhs, &ts_parts, nullability),
58 }
59 }
60}
61
62register_kernel!(CompareKernelAdapter(DateTimePartsVTable).lift());
63
64fn compare_eq(
65 lhs: &DateTimePartsArray,
66 ts_parts: ×tamp::TimestampParts,
67 nullability: Nullability,
68) -> VortexResult<Option<ArrayRef>> {
69 let mut comparison = compare_dtp(lhs.days(), ts_parts.days, Operator::Eq, nullability)?;
70 if comparison.statistics().compute_max::<bool>() == Some(false) {
71 return Ok(Some(comparison));
73 }
74
75 comparison = and(
76 &compare_dtp(lhs.seconds(), ts_parts.seconds, Operator::Eq, nullability)?,
77 &comparison,
78 )?;
79
80 if comparison.statistics().compute_max::<bool>() == Some(false) {
81 return Ok(Some(comparison));
83 }
84
85 comparison = and(
86 &compare_dtp(
87 lhs.subseconds(),
88 ts_parts.subseconds,
89 Operator::Eq,
90 nullability,
91 )?,
92 &comparison,
93 )?;
94
95 Ok(Some(comparison))
96}
97
98fn compare_ne(
99 lhs: &DateTimePartsArray,
100 ts_parts: ×tamp::TimestampParts,
101 nullability: Nullability,
102) -> VortexResult<Option<ArrayRef>> {
103 let mut comparison = compare_dtp(lhs.days(), ts_parts.days, Operator::NotEq, nullability)?;
104 if comparison.statistics().compute_min::<bool>() == Some(true) {
105 return Ok(Some(comparison));
107 }
108
109 comparison = or(
110 &compare_dtp(
111 lhs.seconds(),
112 ts_parts.seconds,
113 Operator::NotEq,
114 nullability,
115 )?,
116 &comparison,
117 )?;
118
119 if comparison.statistics().compute_min::<bool>() == Some(true) {
120 return Ok(Some(comparison));
122 }
123
124 comparison = or(
125 &compare_dtp(
126 lhs.subseconds(),
127 ts_parts.subseconds,
128 Operator::NotEq,
129 nullability,
130 )?,
131 &comparison,
132 )?;
133
134 Ok(Some(comparison))
135}
136
137fn compare_lt(
138 lhs: &DateTimePartsArray,
139 ts_parts: ×tamp::TimestampParts,
140 nullability: Nullability,
141) -> VortexResult<Option<ArrayRef>> {
142 let days_lt = compare_dtp(lhs.days(), ts_parts.days, Operator::Lt, nullability)?;
143 if days_lt.statistics().compute_min::<bool>() == Some(true) {
144 return Ok(Some(days_lt));
146 }
147
148 Ok(None)
149}
150
151fn compare_gt(
152 lhs: &DateTimePartsArray,
153 ts_parts: ×tamp::TimestampParts,
154 nullability: Nullability,
155) -> VortexResult<Option<ArrayRef>> {
156 let days_gt = compare_dtp(lhs.days(), ts_parts.days, Operator::Gt, nullability)?;
157 if days_gt.statistics().compute_min::<bool>() == Some(true) {
158 return Ok(Some(days_gt));
160 }
161
162 Ok(None)
163}
164
165fn compare_dtp(
166 lhs: &dyn Array,
167 rhs: i64,
168 operator: Operator,
169 nullability: Nullability,
170) -> VortexResult<ArrayRef> {
171 match cast(
173 ConstantArray::new(rhs, lhs.len()).as_ref(),
174 &lhs.dtype().with_nullability(nullability),
175 ) {
176 Ok(casted) => compare(lhs, &casted, operator),
177 _ => {
179 let constant_value = match operator {
180 Operator::Eq | Operator::Gte | Operator::Gt => false,
181 Operator::NotEq | Operator::Lte | Operator::Lt => true,
182 };
183 Ok(
184 ConstantArray::new(Scalar::bool(constant_value, nullability), lhs.len())
185 .into_array(),
186 )
187 }
188 }
189}
190
191#[cfg(test)]
192mod test {
193 use rstest::rstest;
194 use vortex_array::arrays::{PrimitiveArray, TemporalArray};
195 use vortex_array::compute::Operator;
196 use vortex_array::validity::Validity;
197 use vortex_buffer::buffer;
198 use vortex_dtype::NativePType;
199 use vortex_dtype::datetime::TimeUnit;
200
201 use super::*;
202
203 fn dtp_array_from_timestamp<T: NativePType>(
204 value: T,
205 validity: Validity,
206 ) -> DateTimePartsArray {
207 DateTimePartsArray::try_from(TemporalArray::new_timestamp(
208 PrimitiveArray::new(buffer![value], validity).into_array(),
209 TimeUnit::S,
210 Some("UTC".to_string()),
211 ))
212 .expect("Failed to construct DateTimePartsArray from TemporalArray")
213 }
214
215 #[rstest]
216 #[case(Validity::NonNullable, Validity::NonNullable)]
217 #[case(Validity::NonNullable, Validity::AllValid)]
218 #[case(Validity::AllValid, Validity::NonNullable)]
219 #[case(Validity::AllValid, Validity::AllValid)]
220 fn compare_date_time_parts_eq(#[case] lhs_validity: Validity, #[case] rhs_validity: Validity) {
221 let lhs = dtp_array_from_timestamp(86400i64, lhs_validity); let rhs = dtp_array_from_timestamp(86400i64, rhs_validity.clone()); let comparison = compare(lhs.as_ref(), rhs.as_ref(), Operator::Eq).unwrap();
224 assert_eq!(comparison.as_bool_typed().true_count().unwrap(), 1);
225
226 let rhs = dtp_array_from_timestamp(0i64, rhs_validity); let comparison = compare(lhs.as_ref(), rhs.as_ref(), Operator::Eq).unwrap();
228 assert_eq!(comparison.as_bool_typed().true_count().unwrap(), 0);
229 }
230
231 #[rstest]
232 #[case(Validity::NonNullable, Validity::NonNullable)]
233 #[case(Validity::NonNullable, Validity::AllValid)]
234 #[case(Validity::AllValid, Validity::NonNullable)]
235 #[case(Validity::AllValid, Validity::AllValid)]
236 fn compare_date_time_parts_ne(#[case] lhs_validity: Validity, #[case] rhs_validity: Validity) {
237 let lhs = dtp_array_from_timestamp(86400i64, lhs_validity); let rhs = dtp_array_from_timestamp(86401i64, rhs_validity.clone()); let comparison = compare(lhs.as_ref(), rhs.as_ref(), Operator::NotEq).unwrap();
240 assert_eq!(comparison.as_bool_typed().true_count().unwrap(), 1);
241
242 let rhs = dtp_array_from_timestamp(86400i64, rhs_validity); let comparison = compare(lhs.as_ref(), rhs.as_ref(), Operator::NotEq).unwrap();
244 assert_eq!(comparison.as_bool_typed().true_count().unwrap(), 0);
245 }
246
247 #[rstest]
248 #[case(Validity::NonNullable, Validity::NonNullable)]
249 #[case(Validity::NonNullable, Validity::AllValid)]
250 #[case(Validity::AllValid, Validity::NonNullable)]
251 #[case(Validity::AllValid, Validity::AllValid)]
252 fn compare_date_time_parts_lt(#[case] lhs_validity: Validity, #[case] rhs_validity: Validity) {
253 let lhs = dtp_array_from_timestamp(0i64, lhs_validity); let rhs = dtp_array_from_timestamp(86400i64, rhs_validity); let comparison = compare(lhs.as_ref(), rhs.as_ref(), Operator::Lt).unwrap();
257 assert_eq!(comparison.as_bool_typed().true_count().unwrap(), 1);
258 }
259
260 #[rstest]
261 #[case(Validity::NonNullable, Validity::NonNullable)]
262 #[case(Validity::NonNullable, Validity::AllValid)]
263 #[case(Validity::AllValid, Validity::NonNullable)]
264 #[case(Validity::AllValid, Validity::AllValid)]
265 fn compare_date_time_parts_gt(#[case] lhs_validity: Validity, #[case] rhs_validity: Validity) {
266 let lhs = dtp_array_from_timestamp(86400i64, lhs_validity); let rhs = dtp_array_from_timestamp(0i64, rhs_validity); let comparison = compare(lhs.as_ref(), rhs.as_ref(), Operator::Gt).unwrap();
270 assert_eq!(comparison.as_bool_typed().true_count().unwrap(), 1);
271 }
272
273 #[rstest]
274 #[case(Validity::NonNullable, Validity::NonNullable)]
275 #[case(Validity::NonNullable, Validity::AllValid)]
276 #[case(Validity::AllValid, Validity::NonNullable)]
277 #[case(Validity::AllValid, Validity::AllValid)]
278 fn compare_date_time_parts_narrowing(
279 #[case] lhs_validity: Validity,
280 #[case] rhs_validity: Validity,
281 ) {
282 let temporal_array = TemporalArray::new_timestamp(
283 PrimitiveArray::new(buffer![0i64], lhs_validity.clone()).into_array(),
284 TimeUnit::S,
285 Some("UTC".to_string()),
286 );
287
288 let lhs = DateTimePartsArray::try_new(
289 DType::Extension(temporal_array.ext_dtype()),
290 PrimitiveArray::new(buffer![0i32], lhs_validity).into_array(),
291 PrimitiveArray::new(buffer![0u32], Validity::NonNullable).into_array(),
292 PrimitiveArray::new(buffer![0i64], Validity::NonNullable).into_array(),
293 )
294 .unwrap();
295
296 let rhs = dtp_array_from_timestamp(i64::MAX, rhs_validity);
298
299 let comparison = compare(lhs.as_ref(), rhs.as_ref(), Operator::Eq).unwrap();
300 assert_eq!(comparison.as_bool_typed().true_count().unwrap(), 0);
301
302 let comparison = compare(lhs.as_ref(), rhs.as_ref(), Operator::NotEq).unwrap();
303 assert_eq!(comparison.as_bool_typed().true_count().unwrap(), 1);
304
305 let comparison = compare(lhs.as_ref(), rhs.as_ref(), Operator::Lt).unwrap();
306 assert_eq!(comparison.as_bool_typed().true_count().unwrap(), 1);
307
308 let comparison = compare(lhs.as_ref(), rhs.as_ref(), Operator::Lte).unwrap();
309 assert_eq!(comparison.as_bool_typed().true_count().unwrap(), 1);
310
311 }
314}