vortex_datetime_parts/compute/
compare.rs1use vortex_array::arrays::ConstantArray;
2use vortex_array::compute::{
3 CompareKernel, CompareKernelAdapter, Operator, and, cast, compare, or,
4};
5use vortex_array::{Array, ArrayRef, IntoArray, register_kernel};
6use vortex_dtype::DType;
7use vortex_dtype::datetime::TemporalMetadata;
8use vortex_error::{VortexExpect as _, VortexResult};
9
10use crate::array::{DateTimePartsArray, DateTimePartsVTable};
11use crate::timestamp;
12
13impl CompareKernel for DateTimePartsVTable {
14 fn compare(
17 &self,
18 lhs: &DateTimePartsArray,
19 rhs: &dyn Array,
20 operator: Operator,
21 ) -> VortexResult<Option<ArrayRef>> {
22 let Some(rhs_const) = rhs.as_constant() else {
23 return Ok(None);
24 };
25 let Ok(timestamp) = rhs_const
26 .as_extension()
27 .storage()
28 .as_primitive()
29 .as_::<i64>()
30 .map(|maybe_value| maybe_value.vortex_expect("null scalar handled in top-level"))
31 else {
32 return Ok(None);
33 };
34
35 let DType::Extension(ext_dtype) = rhs_const.dtype() else {
36 return Ok(None);
37 };
38
39 let temporal_metadata = TemporalMetadata::try_from(ext_dtype.as_ref())?;
40 let ts_parts = timestamp::split(timestamp, temporal_metadata.time_unit())?;
41
42 match operator {
43 Operator::Eq => compare_eq(lhs, &ts_parts),
44 Operator::NotEq => compare_ne(lhs, &ts_parts),
45 Operator::Lt => compare_lt(lhs, &ts_parts),
51 Operator::Lte => compare_lt(lhs, &ts_parts),
52 Operator::Gt => compare_gt(lhs, &ts_parts),
54 Operator::Gte => compare_gt(lhs, &ts_parts),
55 }
56 }
57}
58
59register_kernel!(CompareKernelAdapter(DateTimePartsVTable).lift());
60
61fn compare_eq(
62 lhs: &DateTimePartsArray,
63 ts_parts: ×tamp::TimestampParts,
64) -> VortexResult<Option<ArrayRef>> {
65 let mut comparison = compare_dtp(lhs.days(), ts_parts.days, Operator::Eq)?;
66 if comparison.statistics().compute_max::<bool>() == Some(false) {
67 return Ok(Some(comparison));
69 }
70
71 comparison = and(
72 &compare_dtp(lhs.seconds(), ts_parts.seconds, Operator::Eq)?,
73 &comparison,
74 )?;
75
76 if comparison.statistics().compute_max::<bool>() == Some(false) {
77 return Ok(Some(comparison));
79 }
80
81 comparison = and(
82 &compare_dtp(lhs.subseconds(), ts_parts.subseconds, Operator::Eq)?,
83 &comparison,
84 )?;
85
86 Ok(Some(comparison))
87}
88
89fn compare_ne(
90 lhs: &DateTimePartsArray,
91 ts_parts: ×tamp::TimestampParts,
92) -> VortexResult<Option<ArrayRef>> {
93 let mut comparison = compare_dtp(lhs.days(), ts_parts.days, Operator::NotEq)?;
94 if comparison.statistics().compute_min::<bool>() == Some(true) {
95 return Ok(Some(comparison));
97 }
98
99 comparison = or(
100 &compare_dtp(lhs.seconds(), ts_parts.seconds, Operator::NotEq)?,
101 &comparison,
102 )?;
103
104 if comparison.statistics().compute_min::<bool>() == Some(true) {
105 return Ok(Some(comparison));
107 }
108
109 comparison = or(
110 &compare_dtp(lhs.subseconds(), ts_parts.subseconds, Operator::NotEq)?,
111 &comparison,
112 )?;
113
114 Ok(Some(comparison))
115}
116
117fn compare_lt(
118 lhs: &DateTimePartsArray,
119 ts_parts: ×tamp::TimestampParts,
120) -> VortexResult<Option<ArrayRef>> {
121 let days_lt = compare_dtp(lhs.days(), ts_parts.days, Operator::Lt)?;
122 if days_lt.statistics().compute_min::<bool>() == Some(true) {
123 return Ok(Some(days_lt));
125 }
126
127 Ok(None)
128}
129
130fn compare_gt(
131 lhs: &DateTimePartsArray,
132 ts_parts: ×tamp::TimestampParts,
133) -> VortexResult<Option<ArrayRef>> {
134 let days_gt = compare_dtp(lhs.days(), ts_parts.days, Operator::Gt)?;
135 if days_gt.statistics().compute_min::<bool>() == Some(true) {
136 return Ok(Some(days_gt));
138 }
139
140 Ok(None)
141}
142
143fn compare_dtp(lhs: &dyn Array, rhs: i64, operator: Operator) -> VortexResult<ArrayRef> {
144 match cast(ConstantArray::new(rhs, lhs.len()).as_ref(), lhs.dtype()) {
145 Ok(casted) => compare(lhs, &casted, operator),
146 _ => {
148 let constant_value = match operator {
149 Operator::Eq | Operator::Gte | Operator::Gt => false,
150 Operator::NotEq | Operator::Lte | Operator::Lt => true,
151 };
152 Ok(ConstantArray::new(constant_value, lhs.len()).into_array())
153 }
154 }
155}
156
#[cfg(test)]
mod test {
    use vortex_array::arrays::{PrimitiveArray, TemporalArray};
    use vortex_array::compute::Operator;
    use vortex_array::validity::Validity;
    use vortex_buffer::buffer;
    use vortex_dtype::NativePType;
    use vortex_dtype::datetime::TimeUnit;

    use super::*;

    /// Builds a one-element, second-resolution UTC timestamp array encoded as date-time parts.
    fn dtp_array_from_timestamp<T: NativePType>(value: T) -> DateTimePartsArray {
        let timestamps = PrimitiveArray::new(buffer![value], Validity::NonNullable).into_array();
        let temporal =
            TemporalArray::new_timestamp(timestamps, TimeUnit::S, Some("UTC".to_string()));
        DateTimePartsArray::try_from(temporal)
            .expect("Failed to construct DateTimePartsArray from TemporalArray")
    }

    /// Runs `compare(lhs, rhs, operator)` and returns how many rows evaluated to `true`.
    fn true_count(lhs: &dyn Array, rhs: &dyn Array, operator: Operator) -> usize {
        compare(lhs, rhs, operator)
            .unwrap()
            .as_bool_typed()
            .true_count()
            .unwrap()
    }

    #[test]
    fn compare_date_time_parts_eq() {
        let lhs = dtp_array_from_timestamp(86400i64);

        let same = dtp_array_from_timestamp(86400i64);
        assert_eq!(true_count(lhs.as_ref(), same.as_ref(), Operator::Eq), 1);

        let epoch = dtp_array_from_timestamp(0i64);
        assert_eq!(true_count(lhs.as_ref(), epoch.as_ref(), Operator::Eq), 0);
    }

    #[test]
    fn compare_date_time_parts_ne() {
        let lhs = dtp_array_from_timestamp(86400i64);

        let one_second_later = dtp_array_from_timestamp(86401i64);
        assert_eq!(
            true_count(lhs.as_ref(), one_second_later.as_ref(), Operator::NotEq),
            1
        );

        let same = dtp_array_from_timestamp(86400i64);
        assert_eq!(true_count(lhs.as_ref(), same.as_ref(), Operator::NotEq), 0);
    }

    #[test]
    fn compare_date_time_parts_lt() {
        let lhs = dtp_array_from_timestamp(0i64);
        let rhs = dtp_array_from_timestamp(86400i64);
        assert_eq!(true_count(lhs.as_ref(), rhs.as_ref(), Operator::Lt), 1);
    }

    #[test]
    fn compare_date_time_parts_gt() {
        let lhs = dtp_array_from_timestamp(86400i64);
        let rhs = dtp_array_from_timestamp(0i64);
        assert_eq!(true_count(lhs.as_ref(), rhs.as_ref(), Operator::Gt), 1);
    }

    #[test]
    fn compare_date_time_parts_narrowing() {
        let temporal_array = TemporalArray::new_timestamp(
            PrimitiveArray::new(buffer![0i64], Validity::NonNullable).into_array(),
            TimeUnit::S,
            Some("UTC".to_string()),
        );

        // Hand-built parts with an `i32` days child, so a large constant day count cannot be
        // cast to the child's type and `compare_dtp` takes its fallback branch.
        let lhs = DateTimePartsArray::try_new(
            DType::Extension(temporal_array.ext_dtype()),
            PrimitiveArray::new(buffer![0i32], Validity::NonNullable).into_array(),
            PrimitiveArray::new(buffer![0u32], Validity::NonNullable).into_array(),
            PrimitiveArray::new(buffer![0i64], Validity::NonNullable).into_array(),
        )
        .unwrap();

        // `i64::MAX` seconds is roughly 1.07e14 days, far beyond `i32::MAX`.
        let rhs = dtp_array_from_timestamp(i64::MAX);

        let cases = [
            (Operator::Eq, 0),
            (Operator::NotEq, 1),
            (Operator::Lt, 1),
            (Operator::Lte, 1),
        ];
        for (operator, expected) in cases {
            assert_eq!(true_count(lhs.as_ref(), rhs.as_ref(), operator), expected);
        }
    }
}