vortex_datetime_parts/compute/
compare.rs

1use vortex_array::arrays::ConstantArray;
2use vortex_array::compute::{
3    CompareKernel, CompareKernelAdapter, Operator, and, cast, compare, or,
4};
5use vortex_array::{Array, ArrayRef, IntoArray, register_kernel};
6use vortex_dtype::datetime::TemporalMetadata;
7use vortex_dtype::{DType, Nullability};
8use vortex_error::{VortexExpect as _, VortexResult};
9use vortex_scalar::Scalar;
10
11use crate::array::{DateTimePartsArray, DateTimePartsVTable};
12use crate::timestamp;
13
14impl CompareKernel for DateTimePartsVTable {
15    /// Compares two arrays and returns a new boolean array with the result of the comparison.
16    /// Or, returns None if comparison is not supported.
17    fn compare(
18        &self,
19        lhs: &DateTimePartsArray,
20        rhs: &dyn Array,
21        operator: Operator,
22    ) -> VortexResult<Option<ArrayRef>> {
23        let Some(rhs_const) = rhs.as_constant() else {
24            return Ok(None);
25        };
26        let Ok(timestamp) = rhs_const
27            .as_extension()
28            .storage()
29            .as_primitive()
30            .as_::<i64>()
31            .map(|maybe_value| maybe_value.vortex_expect("null scalar handled in top-level"))
32        else {
33            return Ok(None);
34        };
35
36        let DType::Extension(ext_dtype) = rhs_const.dtype() else {
37            return Ok(None);
38        };
39
40        let nullability = lhs.dtype().nullability() | rhs.dtype().nullability();
41
42        let temporal_metadata = TemporalMetadata::try_from(ext_dtype.as_ref())?;
43        let ts_parts = timestamp::split(timestamp, temporal_metadata.time_unit())?;
44
45        match operator {
46            Operator::Eq => compare_eq(lhs, &ts_parts, nullability),
47            Operator::NotEq => compare_ne(lhs, &ts_parts, nullability),
48            // lt and lte have identical behavior, as we optimize
49            // for the case that all days on the lhs are smaller.
50            //
51            // If that special case is not hit, we return `Ok(None)` to
52            // signal that the comparison wasn't handled within dtp.
53            Operator::Lt => compare_lt(lhs, &ts_parts, nullability),
54            Operator::Lte => compare_lt(lhs, &ts_parts, nullability),
55            // (Like for lt, lte)
56            Operator::Gt => compare_gt(lhs, &ts_parts, nullability),
57            Operator::Gte => compare_gt(lhs, &ts_parts, nullability),
58        }
59    }
60}
61
// Register the `CompareKernel` implementation above, wrapped in `CompareKernelAdapter`, via
// `register_kernel!`. Presumably this is how the generic `compare` entry point discovers this
// kernel for `DateTimePartsArray` operands — confirm against the macro's definition.
register_kernel!(CompareKernelAdapter(DateTimePartsVTable).lift());
63
64fn compare_eq(
65    lhs: &DateTimePartsArray,
66    ts_parts: &timestamp::TimestampParts,
67    nullability: Nullability,
68) -> VortexResult<Option<ArrayRef>> {
69    let mut comparison = compare_dtp(lhs.days(), ts_parts.days, Operator::Eq, nullability)?;
70    if comparison.statistics().compute_max::<bool>() == Some(false) {
71        // All values are different.
72        return Ok(Some(comparison));
73    }
74
75    comparison = and(
76        &compare_dtp(lhs.seconds(), ts_parts.seconds, Operator::Eq, nullability)?,
77        &comparison,
78    )?;
79
80    if comparison.statistics().compute_max::<bool>() == Some(false) {
81        // All values are different.
82        return Ok(Some(comparison));
83    }
84
85    comparison = and(
86        &compare_dtp(
87            lhs.subseconds(),
88            ts_parts.subseconds,
89            Operator::Eq,
90            nullability,
91        )?,
92        &comparison,
93    )?;
94
95    Ok(Some(comparison))
96}
97
98fn compare_ne(
99    lhs: &DateTimePartsArray,
100    ts_parts: &timestamp::TimestampParts,
101    nullability: Nullability,
102) -> VortexResult<Option<ArrayRef>> {
103    let mut comparison = compare_dtp(lhs.days(), ts_parts.days, Operator::NotEq, nullability)?;
104    if comparison.statistics().compute_min::<bool>() == Some(true) {
105        // All values are different.
106        return Ok(Some(comparison));
107    }
108
109    comparison = or(
110        &compare_dtp(
111            lhs.seconds(),
112            ts_parts.seconds,
113            Operator::NotEq,
114            nullability,
115        )?,
116        &comparison,
117    )?;
118
119    if comparison.statistics().compute_min::<bool>() == Some(true) {
120        // All values are different.
121        return Ok(Some(comparison));
122    }
123
124    comparison = or(
125        &compare_dtp(
126            lhs.subseconds(),
127            ts_parts.subseconds,
128            Operator::NotEq,
129            nullability,
130        )?,
131        &comparison,
132    )?;
133
134    Ok(Some(comparison))
135}
136
137fn compare_lt(
138    lhs: &DateTimePartsArray,
139    ts_parts: &timestamp::TimestampParts,
140    nullability: Nullability,
141) -> VortexResult<Option<ArrayRef>> {
142    let days_lt = compare_dtp(lhs.days(), ts_parts.days, Operator::Lt, nullability)?;
143    if days_lt.statistics().compute_min::<bool>() == Some(true) {
144        // All values on the lhs are smaller.
145        return Ok(Some(days_lt));
146    }
147
148    Ok(None)
149}
150
151fn compare_gt(
152    lhs: &DateTimePartsArray,
153    ts_parts: &timestamp::TimestampParts,
154    nullability: Nullability,
155) -> VortexResult<Option<ArrayRef>> {
156    let days_gt = compare_dtp(lhs.days(), ts_parts.days, Operator::Gt, nullability)?;
157    if days_gt.statistics().compute_min::<bool>() == Some(true) {
158        // All values on the lhs are larger.
159        return Ok(Some(days_gt));
160    }
161
162    Ok(None)
163}
164
165fn compare_dtp(
166    lhs: &dyn Array,
167    rhs: i64,
168    operator: Operator,
169    nullability: Nullability,
170) -> VortexResult<ArrayRef> {
171    // Since nullability is stripped from RHS and carried forward through nullability argument we want to incorporate it into lhs.dtype() that we cast rhs into
172    match cast(
173        ConstantArray::new(rhs, lhs.len()).as_ref(),
174        &lhs.dtype().with_nullability(nullability),
175    ) {
176        Ok(casted) => compare(lhs, &casted, operator),
177        // The narrowing cast failed. Therefore, we know lhs < rhs.
178        _ => {
179            let constant_value = match operator {
180                Operator::Eq | Operator::Gte | Operator::Gt => false,
181                Operator::NotEq | Operator::Lte | Operator::Lt => true,
182            };
183            Ok(
184                ConstantArray::new(Scalar::bool(constant_value, nullability), lhs.len())
185                    .into_array(),
186            )
187        }
188    }
189}
190
#[cfg(test)]
mod test {
    use rstest::rstest;
    use vortex_array::arrays::{PrimitiveArray, TemporalArray};
    use vortex_array::compute::Operator;
    use vortex_array::validity::Validity;
    use vortex_buffer::buffer;
    use vortex_dtype::NativePType;
    use vortex_dtype::datetime::TimeUnit;

    use super::*;

    /// Builds a single-element `DateTimePartsArray` from a timestamp in seconds, tagged "UTC".
    fn dtp_array_from_timestamp<T: NativePType>(
        value: T,
        validity: Validity,
    ) -> DateTimePartsArray {
        DateTimePartsArray::try_from(TemporalArray::new_timestamp(
            PrimitiveArray::new(buffer![value], validity).into_array(),
            TimeUnit::S,
            Some("UTC".to_string()),
        ))
        .expect("Failed to construct DateTimePartsArray from TemporalArray")
    }

    #[rstest]
    #[case(Validity::NonNullable, Validity::NonNullable)]
    #[case(Validity::NonNullable, Validity::AllValid)]
    #[case(Validity::AllValid, Validity::NonNullable)]
    #[case(Validity::AllValid, Validity::AllValid)]
    fn compare_date_time_parts_eq(#[case] lhs_validity: Validity, #[case] rhs_validity: Validity) {
        let lhs = dtp_array_from_timestamp(86400i64, lhs_validity); // January 2, 1970, 00:00:00 UTC
        let rhs = dtp_array_from_timestamp(86400i64, rhs_validity.clone()); // January 2, 1970, 00:00:00 UTC
        let comparison = compare(lhs.as_ref(), rhs.as_ref(), Operator::Eq).unwrap();
        assert_eq!(comparison.as_bool_typed().true_count().unwrap(), 1);

        let rhs = dtp_array_from_timestamp(0i64, rhs_validity); // January 1, 1970, 00:00:00 UTC
        let comparison = compare(lhs.as_ref(), rhs.as_ref(), Operator::Eq).unwrap();
        assert_eq!(comparison.as_bool_typed().true_count().unwrap(), 0);
    }

    #[rstest]
    #[case(Validity::NonNullable, Validity::NonNullable)]
    #[case(Validity::NonNullable, Validity::AllValid)]
    #[case(Validity::AllValid, Validity::NonNullable)]
    #[case(Validity::AllValid, Validity::AllValid)]
    fn compare_date_time_parts_ne(#[case] lhs_validity: Validity, #[case] rhs_validity: Validity) {
        let lhs = dtp_array_from_timestamp(86400i64, lhs_validity); // January 2, 1970, 00:00:00 UTC
        let rhs = dtp_array_from_timestamp(86401i64, rhs_validity.clone()); // January 2, 1970, 00:00:01 UTC
        let comparison = compare(lhs.as_ref(), rhs.as_ref(), Operator::NotEq).unwrap();
        assert_eq!(comparison.as_bool_typed().true_count().unwrap(), 1);

        let rhs = dtp_array_from_timestamp(86400i64, rhs_validity); // January 2, 1970, 00:00:00 UTC
        let comparison = compare(lhs.as_ref(), rhs.as_ref(), Operator::NotEq).unwrap();
        assert_eq!(comparison.as_bool_typed().true_count().unwrap(), 0);
    }

    #[rstest]
    #[case(Validity::NonNullable, Validity::NonNullable)]
    #[case(Validity::NonNullable, Validity::AllValid)]
    #[case(Validity::AllValid, Validity::NonNullable)]
    #[case(Validity::AllValid, Validity::AllValid)]
    fn compare_date_time_parts_lt(#[case] lhs_validity: Validity, #[case] rhs_validity: Validity) {
        let lhs = dtp_array_from_timestamp(0i64, lhs_validity); // January 1, 1970, 00:00:00 UTC
        let rhs = dtp_array_from_timestamp(86400i64, rhs_validity); // January 2, 1970, 00:00:00 UTC

        let comparison = compare(lhs.as_ref(), rhs.as_ref(), Operator::Lt).unwrap();
        assert_eq!(comparison.as_bool_typed().true_count().unwrap(), 1);
    }

    #[rstest]
    #[case(Validity::NonNullable, Validity::NonNullable)]
    #[case(Validity::NonNullable, Validity::AllValid)]
    #[case(Validity::AllValid, Validity::NonNullable)]
    #[case(Validity::AllValid, Validity::AllValid)]
    fn compare_date_time_parts_gt(#[case] lhs_validity: Validity, #[case] rhs_validity: Validity) {
        let lhs = dtp_array_from_timestamp(86400i64, lhs_validity); // January 2, 1970, 00:00:00 UTC
        let rhs = dtp_array_from_timestamp(0i64, rhs_validity); // January 1, 1970, 00:00:00 UTC

        let comparison = compare(lhs.as_ref(), rhs.as_ref(), Operator::Gt).unwrap();
        assert_eq!(comparison.as_bool_typed().true_count().unwrap(), 1);
    }

    #[rstest]
    #[case(Validity::NonNullable, Validity::NonNullable)]
    #[case(Validity::NonNullable, Validity::AllValid)]
    #[case(Validity::AllValid, Validity::NonNullable)]
    #[case(Validity::AllValid, Validity::AllValid)]
    fn compare_date_time_parts_narrowing(
        #[case] lhs_validity: Validity,
        #[case] rhs_validity: Validity,
    ) {
        let temporal_array = TemporalArray::new_timestamp(
            PrimitiveArray::new(buffer![0i64], lhs_validity.clone()).into_array(),
            TimeUnit::S,
            Some("UTC".to_string()),
        );

        // The days component is stored as `i32`, narrower than the `i64` the rhs is split into.
        let lhs = DateTimePartsArray::try_new(
            DType::Extension(temporal_array.ext_dtype()),
            PrimitiveArray::new(buffer![0i32], lhs_validity).into_array(),
            PrimitiveArray::new(buffer![0u32], Validity::NonNullable).into_array(),
            PrimitiveArray::new(buffer![0i64], Validity::NonNullable).into_array(),
        )
        .unwrap();

        // Timestamp whose day component (i64::MAX / 86400) is larger than i32::MAX, so it cannot
        // be cast to the lhs days dtype and the narrowing fallback is exercised.
        let rhs = dtp_array_from_timestamp(i64::MAX, rhs_validity);

        let comparison = compare(lhs.as_ref(), rhs.as_ref(), Operator::Eq).unwrap();
        assert_eq!(comparison.as_bool_typed().true_count().unwrap(), 0);

        let comparison = compare(lhs.as_ref(), rhs.as_ref(), Operator::NotEq).unwrap();
        assert_eq!(comparison.as_bool_typed().true_count().unwrap(), 1);

        let comparison = compare(lhs.as_ref(), rhs.as_ref(), Operator::Lt).unwrap();
        assert_eq!(comparison.as_bool_typed().true_count().unwrap(), 1);

        let comparison = compare(lhs.as_ref(), rhs.as_ref(), Operator::Lte).unwrap();
        assert_eq!(comparison.as_bool_typed().true_count().unwrap(), 1);

        // The `Operator::Gt` and `Operator::Gte` fast paths only apply when every lhs value is
        // larger, which is not the case here, so those operators are not exercised by this test.
    }
}