Skip to main content

vortex_pco/compute/
cast.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_array::ArrayRef;
5use vortex_array::ArrayView;
6use vortex_array::IntoArray;
7use vortex_array::LEGACY_SESSION;
8use vortex_array::VortexSessionExecute;
9use vortex_array::dtype::DType;
10use vortex_array::scalar_fn::fns::cast::CastReduce;
11use vortex_array::vtable::child_to_validity;
12use vortex_error::VortexResult;
13
14use crate::Pco;
15use crate::PcoData;
16impl CastReduce for Pco {
17    fn cast(array: ArrayView<'_, Self>, dtype: &DType) -> VortexResult<Option<ArrayRef>> {
18        if !dtype.is_nullable()
19            || !array
20                .array()
21                .all_valid(&mut LEGACY_SESSION.create_execution_ctx())?
22        {
23            // TODO(joe): fixme
24            // We cannot cast to non-nullable since the validity containing nulls is used to decode
25            // the PCO array, this would require rewriting tables.
26            return Ok(None);
27        }
28        // PCO (Pcodec) is a compression encoding that stores data in a compressed format.
29        // It can efficiently handle nullability changes without decompression, but type changes
30        // require decompression since the compression algorithm is type-specific.
31        // PCO supports: F16, F32, F64, I16, I32, I64, U16, U32, U64
32        if array.dtype().eq_ignore_nullability(dtype) {
33            // Create a new validity with the target nullability
34            let unsliced_validity =
35                child_to_validity(&array.slots()[0], array.dtype().nullability());
36            let new_validity =
37                unsliced_validity.cast_nullability(dtype.nullability(), array.len())?;
38
39            let data = PcoData::new(
40                array.chunk_metas.clone(),
41                array.pages.clone(),
42                dtype.as_ptype(),
43                array.metadata.clone(),
44                array.unsliced_n_rows(),
45            )
46            ._slice(array.slice_start(), array.slice_stop());
47
48            return Ok(Some(
49                Pco::try_new(dtype.clone(), data, new_validity)?.into_array(),
50            ));
51        }
52
53        // For other casts (e.g., numeric type changes), decode to canonical and let PrimitiveArray handle it
54        Ok(None)
55    }
56}
57
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_array::IntoArray;
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::assert_arrays_eq;
    use vortex_array::builtins::ArrayBuiltins;
    use vortex_array::compute::conformance::cast::test_cast_conformance;
    use vortex_array::dtype::DType;
    use vortex_array::dtype::Nullability;
    use vortex_array::dtype::PType;
    use vortex_array::validity::Validity;
    use vortex_buffer::buffer;

    use crate::Pco;

    /// Shorthand for building the primitive target dtype of a cast.
    fn primitive(ptype: PType, nullability: Nullability) -> DType {
        DType::Primitive(ptype, nullability)
    }

    /// A type-changing cast (f32 -> f64) must still produce the right dtype and values.
    #[test]
    fn test_cast_pco_f32_to_f64() {
        let source = PrimitiveArray::from_iter([1.0f32, 2.0, 3.0, 4.0, 5.0]);
        let compressed = Pco::from_primitive(source.as_view(), 0, 128).unwrap();
        let target = primitive(PType::F64, Nullability::NonNullable);

        let casted = compressed.into_array().cast(target.clone()).unwrap();
        assert_eq!(casted.dtype(), &target);

        let expected = PrimitiveArray::from_iter([1.0f64, 2.0, 3.0, 4.0, 5.0]);
        assert_arrays_eq!(casted, expected);
    }

    /// Casting from NonNullable to Nullable keeps the values and yields an all-valid array.
    #[test]
    fn test_cast_pco_nullability_change() {
        let source = PrimitiveArray::from_iter([10u32, 20, 30, 40]);
        let compressed = Pco::from_primitive(source.as_view(), 0, 128).unwrap();
        let target = primitive(PType::U32, Nullability::Nullable);

        let casted = compressed.into_array().cast(target).unwrap();

        let expected = PrimitiveArray::new(buffer![10u32, 20, 30, 40], Validity::AllValid);
        assert_arrays_eq!(casted, expected);
    }

    /// A sliced, nullable-but-fully-valid array can be cast to NonNullable.
    #[test]
    fn test_cast_sliced_pco_nullable_to_nonnullable() {
        let source = PrimitiveArray::new(
            buffer![10u32, 20, 30, 40, 50, 60],
            Validity::from_iter([true; 6]),
        );
        let compressed = Pco::from_primitive(source.as_view(), 0, 128).unwrap();
        let window = compressed.slice(1..5).unwrap();
        let target = primitive(PType::U32, Nullability::NonNullable);

        let casted = window.cast(target.clone()).unwrap();
        assert_eq!(casted.dtype(), &target);

        // Verify the values are correct
        let expected = PrimitiveArray::from_iter([20u32, 30, 40, 50]);
        assert_arrays_eq!(casted, expected);
    }

    /// The only null sits outside the slice, so the sliced array can be cast to NonNullable.
    #[test]
    fn test_cast_sliced_pco_part_valid_to_nonnullable() {
        let source = PrimitiveArray::from_option_iter([
            None,
            Some(20u32),
            Some(30),
            Some(40),
            Some(50),
            Some(60),
        ]);
        let compressed = Pco::from_primitive(source.as_view(), 0, 128).unwrap();
        let window = compressed.slice(1..5).unwrap();
        let target = primitive(PType::U32, Nullability::NonNullable);

        let casted = window.cast(target.clone()).unwrap();
        assert_eq!(casted.dtype(), &target);

        let expected = PrimitiveArray::from_iter([20u32, 30, 40, 50]);
        assert_arrays_eq!(casted, expected);
    }

    /// Runs the shared cast conformance suite against Pco arrays of several primitive types.
    #[rstest]
    #[case::f32(PrimitiveArray::new(
        buffer![1.23f32, 4.56, 7.89, 10.11, 12.13],
        Validity::NonNullable,
    ))]
    #[case::f64(PrimitiveArray::new(
        buffer![100.1f64, 200.2, 300.3, 400.4, 500.5],
        Validity::NonNullable,
    ))]
    #[case::i32(PrimitiveArray::new(
        buffer![100i32, 200, 300, 400, 500],
        Validity::NonNullable,
    ))]
    #[case::u64(PrimitiveArray::new(
        buffer![1000u64, 2000, 3000, 4000],
        Validity::NonNullable,
    ))]
    #[case::single(PrimitiveArray::new(
        buffer![42.42f64],
        Validity::NonNullable,
    ))]
    fn test_cast_pco_conformance(#[case] values: PrimitiveArray) {
        let compressed = Pco::from_primitive(values.as_view(), 0, 128).unwrap();
        let as_array = compressed.into_array();
        test_cast_conformance(&as_array);
    }
}