vortex_datetime_parts/
canonical.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use num_traits::AsPrimitive;
5use vortex_array::Canonical;
6use vortex_array::IntoArray;
7use vortex_array::ToCanonical;
8use vortex_array::arrays::PrimitiveArray;
9use vortex_array::arrays::TemporalArray;
10use vortex_array::compute::cast;
11use vortex_array::validity::Validity;
12use vortex_array::vtable::CanonicalVTable;
13use vortex_buffer::BufferMut;
14use vortex_dtype::DType;
15use vortex_dtype::PType;
16use vortex_dtype::datetime::TemporalMetadata;
17use vortex_dtype::datetime::TimeUnit;
18use vortex_dtype::match_each_integer_ptype;
19use vortex_error::VortexExpect as _;
20use vortex_error::vortex_panic;
21
22use crate::DateTimePartsArray;
23use crate::DateTimePartsVTable;
24
25impl CanonicalVTable<DateTimePartsVTable> for DateTimePartsVTable {
26    fn canonicalize(array: &DateTimePartsArray) -> Canonical {
27        Canonical::Extension(decode_to_temporal(array).into())
28    }
29}
30
31/// Decode an [Array] into a [TemporalArray].
32///
33/// Enforces that the passed array is actually a [DateTimePartsArray] with proper metadata.
34pub fn decode_to_temporal(array: &DateTimePartsArray) -> TemporalArray {
35    let DType::Extension(ext) = array.dtype().clone() else {
36        vortex_panic!(ComputeError: "expected dtype to be DType::Extension variant")
37    };
38
39    let Ok(temporal_metadata) = TemporalMetadata::try_from(ext.as_ref()) else {
40        vortex_panic!(ComputeError: "must decode TemporalMetadata from extension metadata");
41    };
42
43    let divisor = match temporal_metadata.time_unit() {
44        TimeUnit::Nanoseconds => 1_000_000_000,
45        TimeUnit::Microseconds => 1_000_000,
46        TimeUnit::Milliseconds => 1_000,
47        TimeUnit::Seconds => 1,
48        TimeUnit::Days => vortex_panic!(InvalidArgument: "cannot decode into TimeUnit::D"),
49    };
50
51    let days_buf = cast(
52        array.days(),
53        &DType::Primitive(PType::I64, array.dtype().nullability()),
54    )
55    .vortex_expect("must be able to cast days to i64")
56    .to_primitive();
57
58    // We start with the days component, which is always present.
59    // And then add the seconds and subseconds components.
60    // We split this into separate passes because often the seconds and/org subseconds components
61    // are constant.
62    let mut values: BufferMut<i64> = days_buf
63        .into_buffer_mut::<i64>()
64        .map_each_in_place(|d| d * 86_400 * divisor);
65
66    if let Some(seconds) = array.seconds().as_constant() {
67        let seconds = seconds
68            .as_primitive()
69            .as_::<i64>()
70            .vortex_expect("non-nullable");
71        let seconds = seconds * divisor;
72        for v in values.iter_mut() {
73            *v += seconds;
74        }
75    } else {
76        let seconds_buf = array.seconds().to_primitive();
77        match_each_integer_ptype!(seconds_buf.ptype(), |S| {
78            for (v, second) in values.iter_mut().zip(seconds_buf.as_slice::<S>()) {
79                let second: i64 = second.as_();
80                *v += second * divisor;
81            }
82        });
83    }
84
85    if let Some(subseconds) = array.subseconds().as_constant() {
86        let subseconds = subseconds
87            .as_primitive()
88            .as_::<i64>()
89            .vortex_expect("non-nullable");
90        for v in values.iter_mut() {
91            *v += subseconds;
92        }
93    } else {
94        let subseconds_buf = array.subseconds().to_primitive();
95        match_each_integer_ptype!(subseconds_buf.ptype(), |S| {
96            for (v, subseconds) in values.iter_mut().zip(subseconds_buf.as_slice::<S>()) {
97                let subseconds: i64 = subseconds.as_();
98                *v += subseconds;
99            }
100        });
101    }
102
103    TemporalArray::new_timestamp(
104        PrimitiveArray::new(values.freeze(), Validity::copy_from_array(array.as_ref()))
105            .into_array(),
106        temporal_metadata.time_unit(),
107        temporal_metadata.time_zone().map(ToString::to_string),
108    )
109}
110
#[cfg(test)]
mod test {
    use rstest::rstest;
    use vortex_array::IntoArray;
    use vortex_array::ToCanonical;
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::arrays::TemporalArray;
    use vortex_array::validity::Validity;
    use vortex_array::vtable::ValidityHelper;
    use vortex_buffer::buffer;
    use vortex_dtype::datetime::TimeUnit;

    use crate::DateTimePartsArray;
    use crate::canonical::decode_to_temporal;

    /// Milliseconds in one second.
    const MS_PER_SECOND: i64 = 1_000;
    /// Milliseconds in one day.
    const MS_PER_DAY: i64 = 86_400 * MS_PER_SECOND;

    /// Round-trips millisecond timestamps through [`DateTimePartsArray`] and
    /// [`decode_to_temporal`], checking that both values and validity survive.
    ///
    /// The inputs are real multiples of a day and a second in milliseconds, so the days,
    /// seconds and subseconds components each take non-zero values in at least one element.
    #[rstest]
    #[case(Validity::NonNullable)]
    #[case(Validity::AllValid)]
    #[case(Validity::AllInvalid)]
    #[case(Validity::from_iter([true, true, false, false, true, true]))]
    fn test_decode_to_temporal(#[case] validity: Validity) {
        let milliseconds = PrimitiveArray::new(
            buffer![
                MS_PER_DAY, // element with only day component
                -MS_PER_DAY,
                MS_PER_DAY + MS_PER_SECOND, // element with day + second components
                -MS_PER_DAY - MS_PER_SECOND,
                MS_PER_DAY + MS_PER_SECOND + 1, // element with day + second + sub-second components
                -MS_PER_DAY - MS_PER_SECOND - 1
            ],
            validity.clone(),
        );
        let date_times = DateTimePartsArray::try_from(TemporalArray::new_timestamp(
            milliseconds.clone().into_array(),
            TimeUnit::Milliseconds,
            Some("UTC".to_string()),
        ))
        .unwrap();

        // Encoding must not change which elements are valid.
        assert_eq!(
            date_times.validity_mask(),
            validity.to_mask(date_times.len())
        );

        let primitive_values = decode_to_temporal(&date_times)
            .temporal_values()
            .to_primitive();

        // Decoding must reproduce the original timestamps and validity exactly.
        assert_eq!(
            primitive_values.as_slice::<i64>(),
            milliseconds.as_slice::<i64>()
        );
        assert_eq!(primitive_values.validity(), &validity);
    }
}