// vortex_datetime_parts/canonical.rs
use vortex_array::arrays::{PrimitiveArray, TemporalArray};
use vortex_array::compute::cast;
use vortex_array::validity::Validity;
use vortex_array::vtable::CanonicalVTable;
use vortex_array::{Canonical, IntoArray, ToCanonical};
use vortex_buffer::BufferMut;
use vortex_dtype::Nullability::NonNullable;
use vortex_dtype::datetime::{TemporalMetadata, TimeUnit};
use vortex_dtype::{DType, PType};
use vortex_error::{VortexExpect as _, VortexResult, vortex_bail};
use vortex_scalar::PrimitiveScalar;

use crate::{DateTimePartsArray, DateTimePartsVTable};

15impl CanonicalVTable<DateTimePartsVTable> for DateTimePartsVTable {
16 fn canonicalize(array: &DateTimePartsArray) -> VortexResult<Canonical> {
17 Ok(Canonical::Extension(decode_to_temporal(array)?.into()))
18 }
19}
20
21pub fn decode_to_temporal(array: &DateTimePartsArray) -> VortexResult<TemporalArray> {
25 let DType::Extension(ext) = array.dtype().clone() else {
26 vortex_bail!(ComputeError: "expected dtype to be DType::Extension variant")
27 };
28
29 let Ok(temporal_metadata) = TemporalMetadata::try_from(ext.as_ref()) else {
30 vortex_bail!(ComputeError: "must decode TemporalMetadata from extension metadata");
31 };
32
33 let divisor = match temporal_metadata.time_unit() {
34 TimeUnit::Ns => 1_000_000_000,
35 TimeUnit::Us => 1_000_000,
36 TimeUnit::Ms => 1_000,
37 TimeUnit::S => 1,
38 TimeUnit::D => vortex_bail!(InvalidArgument: "cannot decode into TimeUnit::D"),
39 };
40
41 let days_buf = cast(
42 array.days(),
43 &DType::Primitive(PType::I64, array.dtype().nullability()),
44 )?
45 .to_primitive()?;
46
47 let mut values: BufferMut<i64> = days_buf
52 .into_buffer_mut::<i64>()
53 .map_each(|d| d * 86_400 * divisor);
54
55 if let Some(seconds) = array.seconds().as_constant() {
56 let seconds =
57 PrimitiveScalar::try_from(&seconds.cast(&DType::Primitive(PType::I64, NonNullable))?)?
58 .typed_value::<i64>()
59 .vortex_expect("non-nullable");
60 let seconds = seconds * divisor;
61 for v in values.iter_mut() {
62 *v += seconds;
63 }
64 } else {
65 let seconds_buf =
66 cast(array.seconds(), &DType::Primitive(PType::U32, NonNullable))?.to_primitive()?;
67 for (v, second) in values.iter_mut().zip(seconds_buf.as_slice::<u32>()) {
68 *v += (*second as i64) * divisor;
69 }
70 }
71
72 if let Some(subseconds) = array.subseconds().as_constant() {
73 let subseconds = PrimitiveScalar::try_from(
74 &subseconds.cast(&DType::Primitive(PType::I64, NonNullable))?,
75 )?
76 .typed_value::<i64>()
77 .vortex_expect("non-nullable");
78 for v in values.iter_mut() {
79 *v += subseconds;
80 }
81 } else {
82 let subsecond_buf = cast(
83 array.subseconds(),
84 &DType::Primitive(PType::I64, NonNullable),
85 )?
86 .to_primitive()?;
87 for (v, subseconds) in values.iter_mut().zip(subsecond_buf.as_slice::<i64>()) {
88 *v += *subseconds;
89 }
90 }
91
92 Ok(TemporalArray::new_timestamp(
93 PrimitiveArray::new(values.freeze(), Validity::copy_from_array(array.as_ref())?)
94 .into_array(),
95 temporal_metadata.time_unit(),
96 temporal_metadata.time_zone().map(ToString::to_string),
97 ))
98}
99
#[cfg(test)]
mod test {
    use rstest::rstest;
    use vortex_array::arrays::{PrimitiveArray, TemporalArray};
    use vortex_array::validity::Validity;
    use vortex_array::vtable::ValidityHelper;
    use vortex_array::{IntoArray, ToCanonical};
    use vortex_buffer::buffer;
    use vortex_dtype::datetime::TimeUnit;

    use crate::DateTimePartsArray;
    use crate::canonical::decode_to_temporal;

    /// Milliseconds in one day (the fixtures below use `TimeUnit::Ms`).
    const MS_PER_DAY: i64 = 86_400_000;

    /// Encodes `milliseconds` into a [`DateTimePartsArray`] and checks that:
    /// - `validity` survives encoding;
    /// - decoding reproduces the original values and validity exactly.
    fn assert_round_trip(milliseconds: PrimitiveArray, validity: &Validity) {
        let date_times = DateTimePartsArray::try_from(TemporalArray::new_timestamp(
            milliseconds.clone().into_array(),
            TimeUnit::Ms,
            Some("UTC".to_string()),
        ))
        .unwrap();

        assert_eq!(
            date_times.validity_mask().unwrap(),
            validity.to_mask(date_times.len()).unwrap()
        );

        let primitive_values = decode_to_temporal(&date_times)
            .unwrap()
            .temporal_values()
            .to_primitive()
            .unwrap();

        assert_eq!(
            primitive_values.as_slice::<i64>(),
            milliseconds.as_slice::<i64>()
        );
        assert_eq!(primitive_values.validity(), validity);
    }

    #[rstest]
    #[case(Validity::NonNullable)]
    #[case(Validity::AllValid)]
    #[case(Validity::AllInvalid)]
    #[case(Validity::from_iter([true, false, true]))]
    fn test_decode_to_temporal(#[case] validity: Validity) {
        // Values are in milliseconds, so a whole day is 86_400_000 (not 86_400):
        // - element 0 has only a day component;
        // - element 1 adds one second;
        // - element 2 adds one second and one millisecond.
        let milliseconds = PrimitiveArray::new(
            buffer![MS_PER_DAY, MS_PER_DAY + 1_000, MS_PER_DAY + 1_000 + 1],
            validity.clone(),
        );
        assert_round_trip(milliseconds, &validity);
    }

    #[rstest]
    #[case(Validity::NonNullable)]
    #[case(Validity::AllValid)]
    #[case(Validity::AllInvalid)]
    #[case(Validity::from_iter([true, false, true]))]
    fn test_decode_midnight_aligned(#[case] validity: Validity) {
        // Whole days only: the seconds and subseconds components are zero for every element.
        let milliseconds = PrimitiveArray::new(
            buffer![0i64, MS_PER_DAY, 2 * MS_PER_DAY],
            validity.clone(),
        );
        assert_round_trip(milliseconds, &validity);
    }
}