Skip to main content

vortex_pco/compute/
cast.rs

// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright the Vortex contributors

4use vortex_array::ArrayRef;
5use vortex_array::ArrayView;
6use vortex_array::IntoArray;
7use vortex_array::dtype::DType;
8use vortex_array::scalar_fn::fns::cast::CastReduce;
9use vortex_array::vtable::child_to_validity;
10use vortex_error::VortexResult;
11
12use crate::Pco;
13use crate::PcoData;
14impl CastReduce for Pco {
15    fn cast(array: ArrayView<'_, Self>, dtype: &DType) -> VortexResult<Option<ArrayRef>> {
16        if !dtype.is_nullable() || !array.array().all_valid()? {
17            // TODO(joe): fixme
18            // We cannot cast to non-nullable since the validity containing nulls is used to decode
19            // the PCO array, this would require rewriting tables.
20            return Ok(None);
21        }
22        // PCO (Pcodec) is a compression encoding that stores data in a compressed format.
23        // It can efficiently handle nullability changes without decompression, but type changes
24        // require decompression since the compression algorithm is type-specific.
25        // PCO supports: F16, F32, F64, I16, I32, I64, U16, U32, U64
26        if array.dtype().eq_ignore_nullability(dtype) {
27            // Create a new validity with the target nullability
28            let unsliced_validity =
29                child_to_validity(&array.slots()[0], array.dtype().nullability());
30            let new_validity =
31                unsliced_validity.cast_nullability(dtype.nullability(), array.len())?;
32
33            let data = PcoData::new(
34                array.chunk_metas.clone(),
35                array.pages.clone(),
36                dtype.as_ptype(),
37                array.metadata.clone(),
38                array.unsliced_n_rows(),
39            )
40            ._slice(array.slice_start(), array.slice_stop());
41
42            return Ok(Some(
43                Pco::try_new(dtype.clone(), data, new_validity)?.into_array(),
44            ));
45        }
46
47        // For other casts (e.g., numeric type changes), decode to canonical and let PrimitiveArray handle it
48        Ok(None)
49    }
50}
51
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_array::IntoArray;
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::assert_arrays_eq;
    use vortex_array::builtins::ArrayBuiltins;
    use vortex_array::compute::conformance::cast::test_cast_conformance;
    use vortex_array::dtype::DType;
    use vortex_array::dtype::Nullability;
    use vortex_array::dtype::PType;
    use vortex_array::validity::Validity;
    use vortex_buffer::buffer;

    use crate::Pco;

    /// Builds the non-nullable primitive dtype for `ptype`.
    fn non_nullable(ptype: PType) -> DType {
        DType::Primitive(ptype, Nullability::NonNullable)
    }

    /// Encodes a six-element u32 array with Pco, slices rows 1..5, and checks that the slice
    /// casts to non-nullable u32 holding `[20, 30, 40, 50]`.
    fn assert_slice_casts_to_non_nullable(values: PrimitiveArray) {
        let encoded = Pco::from_primitive(&values, 0, 128).unwrap();
        let window = encoded.slice(1..5).unwrap();
        let result = window.cast(non_nullable(PType::U32)).unwrap();

        assert_eq!(result.dtype(), &non_nullable(PType::U32));
        assert_arrays_eq!(result, PrimitiveArray::from_iter([20u32, 30, 40, 50]));
    }

    #[test]
    fn test_cast_pco_f32_to_f64() {
        let source = PrimitiveArray::from_iter([1.0f32, 2.0, 3.0, 4.0, 5.0]);
        let encoded = Pco::from_primitive(&source, 0, 128).unwrap();
        let target = non_nullable(PType::F64);

        let result = encoded.into_array().cast(target.clone()).unwrap();

        assert_eq!(result.dtype(), &target);
        assert_arrays_eq!(
            result,
            PrimitiveArray::from_iter([1.0f64, 2.0, 3.0, 4.0, 5.0])
        );
    }

    #[test]
    fn test_cast_pco_nullability_change() {
        // Widening NonNullable -> Nullable must keep every value valid.
        let source = PrimitiveArray::from_iter([10u32, 20, 30, 40]);
        let encoded = Pco::from_primitive(&source, 0, 128).unwrap();

        let widened = encoded
            .into_array()
            .cast(DType::Primitive(PType::U32, Nullability::Nullable))
            .unwrap();

        let expected = PrimitiveArray::new(buffer![10u32, 20, 30, 40], Validity::AllValid);
        assert_arrays_eq!(widened, expected);
    }

    #[test]
    fn test_cast_sliced_pco_nullable_to_nonnullable() {
        let all_valid = Validity::from_iter([true, true, true, true, true, true]);
        assert_slice_casts_to_non_nullable(PrimitiveArray::new(
            buffer![10u32, 20, 30, 40, 50, 60],
            all_valid,
        ));
    }

    #[test]
    fn test_cast_sliced_pco_part_valid_to_nonnullable() {
        // The only null (row 0) lies outside the sliced window.
        assert_slice_casts_to_non_nullable(PrimitiveArray::from_option_iter([
            None,
            Some(20u32),
            Some(30),
            Some(40),
            Some(50),
            Some(60),
        ]));
    }

    #[rstest]
    #[case::f32(PrimitiveArray::new(
        buffer![1.23f32, 4.56, 7.89, 10.11, 12.13],
        Validity::NonNullable,
    ))]
    #[case::f64(PrimitiveArray::new(
        buffer![100.1f64, 200.2, 300.3, 400.4, 500.5],
        Validity::NonNullable,
    ))]
    #[case::i32(PrimitiveArray::new(
        buffer![100i32, 200, 300, 400, 500],
        Validity::NonNullable,
    ))]
    #[case::u64(PrimitiveArray::new(
        buffer![1000u64, 2000, 3000, 4000],
        Validity::NonNullable,
    ))]
    #[case::single(PrimitiveArray::new(
        buffer![42.42f64],
        Validity::NonNullable,
    ))]
    fn test_cast_pco_conformance(#[case] values: PrimitiveArray) {
        let encoded = Pco::from_primitive(&values, 0, 128).unwrap();
        test_cast_conformance(&encoded.into_array());
    }
}