Skip to main content

vortex_pco/compute/
cast.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_array::ArrayRef;
5use vortex_array::ArrayView;
6use vortex_array::IntoArray;
7use vortex_array::dtype::DType;
8use vortex_array::scalar_fn::fns::cast::CastReduce;
9use vortex_array::vtable::child_to_validity;
10use vortex_error::VortexResult;
11
12use crate::Pco;
13use crate::PcoData;
14
15impl CastReduce for Pco {
16    fn cast(array: ArrayView<'_, Self>, dtype: &DType) -> VortexResult<Option<ArrayRef>> {
17        // PCO (Pcodec) stores compressed data and uses validity bits to decode (the validity
18        // tells PCO which logical positions correspond to compressed values). Casting away
19        // nullability would change the validity-to-compressed-value mapping, so we cannot
20        // construct a non-nullable Pco without re-encoding — we only handle nullability changes
21        // toward `Nullable`. Non-nullable targets fall through to canonicalization.
22        //
23        // No `CastKernel` is provided for the same reason: even with execution context, we
24        // cannot cast away nullability on a PCO array in place.
25        //
26        // PCO supports: F16, F32, F64, I16, I32, I64, U16, U32, U64.
27        if !array.dtype().eq_ignore_nullability(dtype) {
28            return Ok(None);
29        }
30
31        let unsliced_validity =
32            child_to_validity(array.slots()[0].as_ref(), array.dtype().nullability());
33        let Some(new_validity) =
34            unsliced_validity.trivial_cast_nullability(dtype.nullability(), array.len())?
35        else {
36            return Ok(None);
37        };
38
39        let data = PcoData::new(
40            array.chunk_metas.clone(),
41            array.pages.clone(),
42            dtype.as_ptype(),
43            array.metadata.clone(),
44            array.unsliced_n_rows(),
45        )
46        ._slice(array.slice_start(), array.slice_stop());
47
48        Ok(Some(
49            Pco::try_new(dtype.clone(), data, new_validity)?.into_array(),
50        ))
51    }
52}
53
#[cfg(test)]
mod tests {
    use std::sync::LazyLock;

    use rstest::rstest;
    use vortex_array::IntoArray;
    use vortex_array::VortexSessionExecute;
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::assert_arrays_eq;
    use vortex_array::builtins::ArrayBuiltins;
    use vortex_array::compute::conformance::cast::test_cast_conformance;
    use vortex_array::dtype::DType;
    use vortex_array::dtype::Nullability;
    use vortex_array::dtype::PType;
    use vortex_array::session::ArraySession;
    use vortex_array::validity::Validity;
    use vortex_buffer::buffer;
    use vortex_session::VortexSession;

    use crate::Pco;

    /// Session shared by all tests; used to create the execution contexts that encode inputs.
    static SESSION: LazyLock<VortexSession> =
        LazyLock::new(|| VortexSession::empty().with::<ArraySession>());

    /// A ptype change cannot be handled by the Pco reduce kernel and must still produce
    /// correct values via the fallback path.
    #[test]
    fn test_cast_pco_f32_to_f64() {
        let mut ctx = SESSION.create_execution_ctx();
        let values = PrimitiveArray::from_iter([1.0f32, 2.0, 3.0, 4.0, 5.0]);
        let pco = Pco::from_primitive(values.as_view(), 0, 128, &mut ctx).unwrap();

        let casted = pco
            .into_array()
            .cast(DType::Primitive(PType::F64, Nullability::NonNullable))
            .unwrap();
        assert_eq!(
            casted.dtype(),
            &DType::Primitive(PType::F64, Nullability::NonNullable)
        );

        assert_arrays_eq!(
            casted,
            PrimitiveArray::from_iter([1.0f64, 2.0, 3.0, 4.0, 5.0])
        );
    }

    #[test]
    fn test_cast_pco_nullability_change() {
        let mut ctx = SESSION.create_execution_ctx();
        // Test casting from NonNullable to Nullable
        let values = PrimitiveArray::from_iter([10u32, 20, 30, 40]);
        let pco = Pco::from_primitive(values.as_view(), 0, 128, &mut ctx).unwrap();

        let casted = pco
            .into_array()
            .cast(DType::Primitive(PType::U32, Nullability::Nullable))
            .unwrap();
        assert_arrays_eq!(
            casted,
            PrimitiveArray::new(buffer![10u32, 20, 30, 40], Validity::AllValid)
        );
    }

    /// Widening a *sliced* non-nullable array must preserve the slice window. This is the
    /// case the Pco reduce kernel is documented to handle, and the kernel rebuilds the
    /// compressed data and re-applies `slice_start`/`slice_stop` itself. None of the other
    /// tests combine slicing with widening toward `Nullable`.
    #[test]
    fn test_cast_sliced_pco_nonnullable_to_nullable() {
        let mut ctx = SESSION.create_execution_ctx();
        let values = PrimitiveArray::from_iter([10u32, 20, 30, 40, 50, 60]);
        let pco = Pco::from_primitive(values.as_view(), 0, 128, &mut ctx).unwrap();
        let sliced = pco.slice(1..5).unwrap();
        let casted = sliced
            .cast(DType::Primitive(PType::U32, Nullability::Nullable))
            .unwrap();
        assert_eq!(
            casted.dtype(),
            &DType::Primitive(PType::U32, Nullability::Nullable)
        );
        assert_arrays_eq!(
            casted,
            PrimitiveArray::new(buffer![20u32, 30, 40, 50], Validity::AllValid)
        );
    }

    #[test]
    fn test_cast_sliced_pco_nullable_to_nonnullable() {
        let mut ctx = SESSION.create_execution_ctx();
        let values = PrimitiveArray::new(
            buffer![10u32, 20, 30, 40, 50, 60],
            Validity::from_iter([true, true, true, true, true, true]),
        );
        let pco = Pco::from_primitive(values.as_view(), 0, 128, &mut ctx).unwrap();
        let sliced = pco.slice(1..5).unwrap();
        let casted = sliced
            .cast(DType::Primitive(PType::U32, Nullability::NonNullable))
            .unwrap();
        assert_eq!(
            casted.dtype(),
            &DType::Primitive(PType::U32, Nullability::NonNullable)
        );
        // Verify the values are correct
        assert_arrays_eq!(casted, PrimitiveArray::from_iter([20u32, 30, 40, 50]));
    }

    /// The null lies outside the slice window, so narrowing the sliced view to
    /// `NonNullable` must succeed even though the unsliced data contains a null.
    #[test]
    fn test_cast_sliced_pco_part_valid_to_nonnullable() {
        let mut ctx = SESSION.create_execution_ctx();
        let values = PrimitiveArray::from_option_iter([
            None,
            Some(20u32),
            Some(30),
            Some(40),
            Some(50),
            Some(60),
        ]);
        let pco = Pco::from_primitive(values.as_view(), 0, 128, &mut ctx).unwrap();
        let sliced = pco.slice(1..5).unwrap();
        let casted = sliced
            .cast(DType::Primitive(PType::U32, Nullability::NonNullable))
            .unwrap();
        assert_eq!(
            casted.dtype(),
            &DType::Primitive(PType::U32, Nullability::NonNullable)
        );
        assert_arrays_eq!(casted, PrimitiveArray::from_iter([20u32, 30, 40, 50]));
    }

    #[rstest]
    #[case::f32(PrimitiveArray::new(
        buffer![1.23f32, 4.56, 7.89, 10.11, 12.13],
        Validity::NonNullable,
    ))]
    #[case::f64(PrimitiveArray::new(
        buffer![100.1f64, 200.2, 300.3, 400.4, 500.5],
        Validity::NonNullable,
    ))]
    #[case::i32(PrimitiveArray::new(
        buffer![100i32, 200, 300, 400, 500],
        Validity::NonNullable,
    ))]
    #[case::u64(PrimitiveArray::new(
        buffer![1000u64, 2000, 3000, 4000],
        Validity::NonNullable,
    ))]
    #[case::single(PrimitiveArray::new(
        buffer![42.42f64],
        Validity::NonNullable,
    ))]
    fn test_cast_pco_conformance(#[case] values: PrimitiveArray) {
        let mut ctx = SESSION.create_execution_ctx();
        let pco = Pco::from_primitive(values.as_view(), 0, 128, &mut ctx).unwrap();
        test_cast_conformance(&pco.into_array());
    }
}