vortex_datetime_parts/compute/
compare.rs

1use vortex_array::arrays::ConstantArray;
2use vortex_array::compute::{
3    CompareKernel, CompareKernelAdapter, Operator, and, cast, compare, or,
4};
5use vortex_array::{Array, ArrayRef, IntoArray, register_kernel};
6use vortex_dtype::DType;
7use vortex_dtype::datetime::TemporalMetadata;
8use vortex_error::{VortexExpect as _, VortexResult};
9
10use crate::array::{DateTimePartsArray, DateTimePartsVTable};
11use crate::timestamp;
12
13impl CompareKernel for DateTimePartsVTable {
14    /// Compares two arrays and returns a new boolean array with the result of the comparison.
15    /// Or, returns None if comparison is not supported.
16    fn compare(
17        &self,
18        lhs: &DateTimePartsArray,
19        rhs: &dyn Array,
20        operator: Operator,
21    ) -> VortexResult<Option<ArrayRef>> {
22        let Some(rhs_const) = rhs.as_constant() else {
23            return Ok(None);
24        };
25        let Ok(timestamp) = rhs_const
26            .as_extension()
27            .storage()
28            .as_primitive()
29            .as_::<i64>()
30            .map(|maybe_value| maybe_value.vortex_expect("null scalar handled in top-level"))
31        else {
32            return Ok(None);
33        };
34
35        let DType::Extension(ext_dtype) = rhs_const.dtype() else {
36            return Ok(None);
37        };
38
39        let temporal_metadata = TemporalMetadata::try_from(ext_dtype.as_ref())?;
40        let ts_parts = timestamp::split(timestamp, temporal_metadata.time_unit())?;
41
42        match operator {
43            Operator::Eq => compare_eq(lhs, &ts_parts),
44            Operator::NotEq => compare_ne(lhs, &ts_parts),
45            // lt and lte have identical behavior, as we optimize
46            // for the case that all days on the lhs are smaller.
47            //
48            // If that special case is not hit, we return `Ok(None)` to
49            // signal that the comparison wasn't handled within dtp.
50            Operator::Lt => compare_lt(lhs, &ts_parts),
51            Operator::Lte => compare_lt(lhs, &ts_parts),
52            // (Like for lt, lte)
53            Operator::Gt => compare_gt(lhs, &ts_parts),
54            Operator::Gte => compare_gt(lhs, &ts_parts),
55        }
56    }
57}
58
// Registers the `CompareKernel` implementation above for `DateTimePartsVTable`.
// NOTE(review): presumably this is what lets the generic `compare` function dispatch to it —
// confirm against the `register_kernel!` documentation.
register_kernel!(CompareKernelAdapter(DateTimePartsVTable).lift());
60
61fn compare_eq(
62    lhs: &DateTimePartsArray,
63    ts_parts: &timestamp::TimestampParts,
64) -> VortexResult<Option<ArrayRef>> {
65    let mut comparison = compare_dtp(lhs.days(), ts_parts.days, Operator::Eq)?;
66    if comparison.statistics().compute_max::<bool>() == Some(false) {
67        // All values are different.
68        return Ok(Some(comparison));
69    }
70
71    comparison = and(
72        &compare_dtp(lhs.seconds(), ts_parts.seconds, Operator::Eq)?,
73        &comparison,
74    )?;
75
76    if comparison.statistics().compute_max::<bool>() == Some(false) {
77        // All values are different.
78        return Ok(Some(comparison));
79    }
80
81    comparison = and(
82        &compare_dtp(lhs.subseconds(), ts_parts.subseconds, Operator::Eq)?,
83        &comparison,
84    )?;
85
86    Ok(Some(comparison))
87}
88
89fn compare_ne(
90    lhs: &DateTimePartsArray,
91    ts_parts: &timestamp::TimestampParts,
92) -> VortexResult<Option<ArrayRef>> {
93    let mut comparison = compare_dtp(lhs.days(), ts_parts.days, Operator::NotEq)?;
94    if comparison.statistics().compute_min::<bool>() == Some(true) {
95        // All values are different.
96        return Ok(Some(comparison));
97    }
98
99    comparison = or(
100        &compare_dtp(lhs.seconds(), ts_parts.seconds, Operator::NotEq)?,
101        &comparison,
102    )?;
103
104    if comparison.statistics().compute_min::<bool>() == Some(true) {
105        // All values are different.
106        return Ok(Some(comparison));
107    }
108
109    comparison = or(
110        &compare_dtp(lhs.subseconds(), ts_parts.subseconds, Operator::NotEq)?,
111        &comparison,
112    )?;
113
114    Ok(Some(comparison))
115}
116
117fn compare_lt(
118    lhs: &DateTimePartsArray,
119    ts_parts: &timestamp::TimestampParts,
120) -> VortexResult<Option<ArrayRef>> {
121    let days_lt = compare_dtp(lhs.days(), ts_parts.days, Operator::Lt)?;
122    if days_lt.statistics().compute_min::<bool>() == Some(true) {
123        // All values on the lhs are smaller.
124        return Ok(Some(days_lt));
125    }
126
127    Ok(None)
128}
129
130fn compare_gt(
131    lhs: &DateTimePartsArray,
132    ts_parts: &timestamp::TimestampParts,
133) -> VortexResult<Option<ArrayRef>> {
134    let days_gt = compare_dtp(lhs.days(), ts_parts.days, Operator::Gt)?;
135    if days_gt.statistics().compute_min::<bool>() == Some(true) {
136        // All values on the lhs are larger.
137        return Ok(Some(days_gt));
138    }
139
140    Ok(None)
141}
142
143fn compare_dtp(lhs: &dyn Array, rhs: i64, operator: Operator) -> VortexResult<ArrayRef> {
144    match cast(ConstantArray::new(rhs, lhs.len()).as_ref(), lhs.dtype()) {
145        Ok(casted) => compare(lhs, &casted, operator),
146        // The narrowing cast failed. Therefore, we know lhs < rhs.
147        _ => {
148            let constant_value = match operator {
149                Operator::Eq | Operator::Gte | Operator::Gt => false,
150                Operator::NotEq | Operator::Lte | Operator::Lt => true,
151            };
152            Ok(ConstantArray::new(constant_value, lhs.len()).into_array())
153        }
154    }
155}
156
#[cfg(test)]
mod test {
    use vortex_array::arrays::{PrimitiveArray, TemporalArray};
    use vortex_array::compute::Operator;
    use vortex_array::validity::Validity;
    use vortex_buffer::buffer;
    use vortex_dtype::NativePType;
    use vortex_dtype::datetime::TimeUnit;

    use super::*;

    /// Builds a one-element, non-nullable UTC timestamp array (unit: seconds) holding
    /// `value` and converts it into a `DateTimePartsArray`.
    fn dtp_array_from_timestamp<T: NativePType>(value: T) -> DateTimePartsArray {
        let storage = PrimitiveArray::new(buffer![value], Validity::NonNullable).into_array();
        let temporal = TemporalArray::new_timestamp(storage, TimeUnit::S, Some("UTC".to_string()));
        DateTimePartsArray::try_from(temporal)
            .expect("Failed to construct DateTimePartsArray from TemporalArray")
    }

    #[test]
    fn compare_date_time_parts_eq() {
        // 86400 s after the epoch: 1970-01-02T00:00:00Z.
        let lhs = dtp_array_from_timestamp(86400i64);

        // Same instant on both sides: the single element matches.
        let same_instant = dtp_array_from_timestamp(86400i64);
        let result = compare(lhs.as_ref(), same_instant.as_ref(), Operator::Eq).unwrap();
        assert_eq!(result.as_bool_typed().true_count().unwrap(), 1);

        // The epoch itself (1970-01-01T00:00:00Z) is a different day: no match.
        let epoch = dtp_array_from_timestamp(0i64);
        let result = compare(lhs.as_ref(), epoch.as_ref(), Operator::Eq).unwrap();
        assert_eq!(result.as_bool_typed().true_count().unwrap(), 0);
    }

    #[test]
    fn compare_date_time_parts_ne() {
        // 86400 s after the epoch: 1970-01-02T00:00:00Z.
        let lhs = dtp_array_from_timestamp(86400i64);

        // One second later on the same day (1970-01-02T00:00:01Z): the element differs.
        let one_second_later = dtp_array_from_timestamp(86401i64);
        let result = compare(lhs.as_ref(), one_second_later.as_ref(), Operator::NotEq).unwrap();
        assert_eq!(result.as_bool_typed().true_count().unwrap(), 1);

        // Same instant on both sides: nothing differs.
        let same_instant = dtp_array_from_timestamp(86400i64);
        let result = compare(lhs.as_ref(), same_instant.as_ref(), Operator::NotEq).unwrap();
        assert_eq!(result.as_bool_typed().true_count().unwrap(), 0);
    }

    #[test]
    fn compare_date_time_parts_lt() {
        // 1970-01-01T00:00:00Z versus 1970-01-02T00:00:00Z: the lhs is a full day earlier.
        let earlier = dtp_array_from_timestamp(0i64);
        let later = dtp_array_from_timestamp(86400i64);

        let result = compare(earlier.as_ref(), later.as_ref(), Operator::Lt).unwrap();
        assert_eq!(result.as_bool_typed().true_count().unwrap(), 1);
    }

    #[test]
    fn compare_date_time_parts_gt() {
        // 1970-01-02T00:00:00Z versus 1970-01-01T00:00:00Z: the lhs is a full day later.
        let later = dtp_array_from_timestamp(86400i64);
        let earlier = dtp_array_from_timestamp(0i64);

        let result = compare(later.as_ref(), earlier.as_ref(), Operator::Gt).unwrap();
        assert_eq!(result.as_bool_typed().true_count().unwrap(), 1);
    }

    #[test]
    fn compare_date_time_parts_narrowing() {
        // Only used to obtain a timestamp extension dtype (seconds, UTC).
        let temporal_array = TemporalArray::new_timestamp(
            PrimitiveArray::new(buffer![0i64], Validity::NonNullable).into_array(),
            TimeUnit::S,
            Some("UTC".to_string()),
        );

        // Hand-built parts array whose day component is stored as `i32`.
        let days = PrimitiveArray::new(buffer![0i32], Validity::NonNullable).into_array();
        let seconds = PrimitiveArray::new(buffer![0u32], Validity::NonNullable).into_array();
        let subseconds = PrimitiveArray::new(buffer![0i64], Validity::NonNullable).into_array();
        let lhs = DateTimePartsArray::try_new(
            DType::Extension(temporal_array.ext_dtype()),
            days,
            seconds,
            subseconds,
        )
        .unwrap();

        // `i64::MAX` seconds corresponds to a day count far beyond `i32::MAX`, so the
        // rhs day cannot be narrowed to the dtype of the lhs day array.
        let rhs = dtp_array_from_timestamp(i64::MAX);

        // `Operator::Gt` and `Operator::Gte` are only handled when all lhs values are
        // larger, which is not the case here, so they are not part of this table.
        let expectations = [
            (Operator::Eq, 0),
            (Operator::NotEq, 1),
            (Operator::Lt, 1),
            (Operator::Lte, 1),
        ];
        for (operator, expected_true_count) in expectations {
            let result = compare(lhs.as_ref(), rhs.as_ref(), operator).unwrap();
            assert_eq!(
                result.as_bool_typed().true_count().unwrap(),
                expected_true_count
            );
        }
    }
}