vortex_datetime_parts/
canonical.rs

1use vortex_array::arrays::{PrimitiveArray, TemporalArray};
2use vortex_array::compute::cast;
3use vortex_array::validity::Validity;
4use vortex_array::vtable::CanonicalVTable;
5use vortex_array::{Canonical, IntoArray, ToCanonical};
6use vortex_buffer::BufferMut;
7use vortex_dtype::Nullability::NonNullable;
8use vortex_dtype::datetime::{TemporalMetadata, TimeUnit};
9use vortex_dtype::{DType, PType};
10use vortex_error::{VortexExpect as _, VortexResult, vortex_bail};
11use vortex_scalar::PrimitiveScalar;
12
13use crate::{DateTimePartsArray, DateTimePartsVTable};
14
15impl CanonicalVTable<DateTimePartsVTable> for DateTimePartsVTable {
16    fn canonicalize(array: &DateTimePartsArray) -> VortexResult<Canonical> {
17        Ok(Canonical::Extension(decode_to_temporal(array)?.into()))
18    }
19}
20
21/// Decode an [Array] into a [TemporalArray].
22///
23/// Enforces that the passed array is actually a [DateTimePartsArray] with proper metadata.
24pub fn decode_to_temporal(array: &DateTimePartsArray) -> VortexResult<TemporalArray> {
25    let DType::Extension(ext) = array.dtype().clone() else {
26        vortex_bail!(ComputeError: "expected dtype to be DType::Extension variant")
27    };
28
29    let Ok(temporal_metadata) = TemporalMetadata::try_from(ext.as_ref()) else {
30        vortex_bail!(ComputeError: "must decode TemporalMetadata from extension metadata");
31    };
32
33    let divisor = match temporal_metadata.time_unit() {
34        TimeUnit::Ns => 1_000_000_000,
35        TimeUnit::Us => 1_000_000,
36        TimeUnit::Ms => 1_000,
37        TimeUnit::S => 1,
38        TimeUnit::D => vortex_bail!(InvalidArgument: "cannot decode into TimeUnit::D"),
39    };
40
41    let days_buf = cast(
42        array.days(),
43        &DType::Primitive(PType::I64, array.dtype().nullability()),
44    )?
45    .to_primitive()?;
46
47    // We start with the days component, which is always present.
48    // And then add the seconds and subseconds components.
49    // We split this into separate passes because often the seconds and/org subseconds components
50    // are constant.
51    let mut values: BufferMut<i64> = days_buf
52        .into_buffer_mut::<i64>()
53        .map_each(|d| d * 86_400 * divisor);
54
55    if let Some(seconds) = array.seconds().as_constant() {
56        let seconds =
57            PrimitiveScalar::try_from(&seconds.cast(&DType::Primitive(PType::I64, NonNullable))?)?
58                .typed_value::<i64>()
59                .vortex_expect("non-nullable");
60        let seconds = seconds * divisor;
61        for v in values.iter_mut() {
62            *v += seconds;
63        }
64    } else {
65        let seconds_buf =
66            cast(array.seconds(), &DType::Primitive(PType::U32, NonNullable))?.to_primitive()?;
67        for (v, second) in values.iter_mut().zip(seconds_buf.as_slice::<u32>()) {
68            *v += (*second as i64) * divisor;
69        }
70    }
71
72    if let Some(subseconds) = array.subseconds().as_constant() {
73        let subseconds = PrimitiveScalar::try_from(
74            &subseconds.cast(&DType::Primitive(PType::I64, NonNullable))?,
75        )?
76        .typed_value::<i64>()
77        .vortex_expect("non-nullable");
78        for v in values.iter_mut() {
79            *v += subseconds;
80        }
81    } else {
82        let subsecond_buf = cast(
83            array.subseconds(),
84            &DType::Primitive(PType::I64, NonNullable),
85        )?
86        .to_primitive()?;
87        for (v, subseconds) in values.iter_mut().zip(subsecond_buf.as_slice::<i64>()) {
88            *v += *subseconds;
89        }
90    }
91
92    Ok(TemporalArray::new_timestamp(
93        PrimitiveArray::new(values.freeze(), Validity::copy_from_array(array.as_ref())?)
94            .into_array(),
95        temporal_metadata.time_unit(),
96        temporal_metadata.time_zone().map(ToString::to_string),
97    ))
98}
99
#[cfg(test)]
mod test {
    use rstest::rstest;
    use vortex_array::arrays::{PrimitiveArray, TemporalArray};
    use vortex_array::validity::Validity;
    use vortex_array::vtable::ValidityHelper;
    use vortex_array::{IntoArray, ToCanonical};
    use vortex_buffer::buffer;
    use vortex_dtype::datetime::TimeUnit;

    use crate::DateTimePartsArray;
    use crate::canonical::decode_to_temporal;

    /// Number of milliseconds in one day.
    ///
    /// The previous fixture used `86_400`, which in milliseconds is only 86.4 seconds and
    /// therefore never produced a non-zero day component.
    const MS_PER_DAY: i64 = 86_400_000;

    /// Round-trips millisecond timestamps through `DateTimePartsArray` and checks that
    /// `decode_to_temporal` restores both the values and the validity.
    #[rstest]
    #[case(Validity::NonNullable)]
    #[case(Validity::AllValid)]
    #[case(Validity::AllInvalid)]
    #[case(Validity::from_iter([true, false, true]))]
    fn test_decode_to_temporal(#[case] validity: Validity) {
        let milliseconds = PrimitiveArray::new(
            buffer![
                MS_PER_DAY,            // element with only day component
                MS_PER_DAY + 1000,     // element with day + second components
                MS_PER_DAY + 1000 + 1, // element with day + second + sub-second components
            ],
            validity.clone(),
        );
        let date_times = DateTimePartsArray::try_from(TemporalArray::new_timestamp(
            milliseconds.clone().into_array(),
            TimeUnit::Ms,
            Some("UTC".to_string()),
        ))
        .unwrap();

        assert_eq!(
            date_times.validity_mask().unwrap(),
            validity.to_mask(date_times.len()).unwrap()
        );

        let primitive_values = decode_to_temporal(&date_times)
            .unwrap()
            .temporal_values()
            .to_primitive()
            .unwrap();

        assert_eq!(
            primitive_values.as_slice::<i64>(),
            milliseconds.as_slice::<i64>()
        );
        assert_eq!(primitive_values.validity(), &validity);
    }
}