Skip to main content

vortex_pco/compute/
cast.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_array::ArrayRef;
5use vortex_array::IntoArray;
6use vortex_array::compute::CastReduce;
7use vortex_dtype::DType;
8use vortex_error::VortexResult;
9
10use crate::PcoArray;
11use crate::PcoVTable;
12
13impl CastReduce for PcoVTable {
14    fn cast(array: &PcoArray, dtype: &DType) -> VortexResult<Option<ArrayRef>> {
15        if !dtype.is_nullable() || !array.all_valid()? {
16            // TODO(joe): fixme
17            // We cannot cast to non-nullable since the validity containing nulls is used to decode
18            // the PCO array, this would require rewriting tables.
19            return Ok(None);
20        }
21        // PCO (Pcodec) is a compression encoding that stores data in a compressed format.
22        // It can efficiently handle nullability changes without decompression, but type changes
23        // require decompression since the compression algorithm is type-specific.
24        // PCO supports: F16, F32, F64, I16, I32, I64, U16, U32, U64
25        if array.dtype().eq_ignore_nullability(dtype) {
26            // Create a new validity with the target nullability
27            let new_validity = array
28                .unsliced_validity
29                .clone()
30                .cast_nullability(dtype.nullability(), array.len())?;
31
32            return Ok(Some(
33                PcoArray::new(
34                    array.chunk_metas.clone(),
35                    array.pages.clone(),
36                    dtype.clone(),
37                    array.metadata.clone(),
38                    array.unsliced_n_rows(),
39                    new_validity,
40                )
41                ._slice(array.slice_start(), array.slice_stop())
42                .into_array(),
43            ));
44        }
45
46        // For other casts (e.g., numeric type changes), decode to canonical and let PrimitiveArray handle it
47        Ok(None)
48    }
49}
50
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::assert_arrays_eq;
    use vortex_array::builtins::ArrayBuiltins;
    use vortex_array::compute::conformance::cast::test_cast_conformance;
    use vortex_array::validity::Validity;
    use vortex_buffer::Buffer;
    use vortex_dtype::DType;
    use vortex_dtype::Nullability;
    use vortex_dtype::PType;

    use crate::PcoArray;

    /// Shorthand for the primitive `DType` of `ptype` with the given nullability.
    fn primitive(ptype: PType, nullability: Nullability) -> DType {
        DType::Primitive(ptype, nullability)
    }

    /// PCO-encodes `values` using the `(0, 128)` arguments shared by every test in this module.
    fn encode(values: &PrimitiveArray) -> PcoArray {
        PcoArray::from_primitive(values, 0, 128).unwrap()
    }

    /// Encodes `values`, takes rows `1..5`, casts that slice to non-nullable `u32`, and checks
    /// both the resulting dtype and that the values are `[20, 30, 40, 50]`.
    fn check_sliced_cast_to_non_nullable(values: &PrimitiveArray) {
        let target = primitive(PType::U32, Nullability::NonNullable);

        let window = encode(values).slice(1..5).unwrap();
        let casted = window.cast(target.clone()).unwrap();

        assert_eq!(casted.dtype(), &target);
        // Verify the values are correct
        assert_arrays_eq!(casted, PrimitiveArray::from_iter([20u32, 30, 40, 50]));
    }

    /// An element-type change (`f32` -> `f64`) must still produce the right dtype and values.
    #[test]
    fn test_cast_pco_f32_to_f64() {
        let source = PrimitiveArray::new(
            Buffer::copy_from(vec![1.0f32, 2.0, 3.0, 4.0, 5.0]),
            Validity::NonNullable,
        );
        let target = primitive(PType::F64, Nullability::NonNullable);

        let widened = encode(&source).to_array().cast(target.clone()).unwrap();

        assert_eq!(widened.dtype(), &target);
        assert_arrays_eq!(
            widened,
            PrimitiveArray::from_iter([1.0f64, 2.0, 3.0, 4.0, 5.0])
        );
    }

    /// Casting from NonNullable to Nullable keeps the element type and flips nullability.
    #[test]
    fn test_cast_pco_nullability_change() {
        let source = PrimitiveArray::new(
            Buffer::copy_from(vec![10u32, 20, 30, 40]),
            Validity::NonNullable,
        );
        let target = primitive(PType::U32, Nullability::Nullable);

        let relaxed = encode(&source).to_array().cast(target.clone()).unwrap();

        assert_eq!(relaxed.dtype(), &target);
    }

    /// Nullable input whose validity marks every row valid, sliced and cast to non-nullable.
    #[test]
    fn test_cast_sliced_pco_nullable_to_nonnullable() {
        let values = PrimitiveArray::new(
            Buffer::copy_from(vec![10u32, 20, 30, 40, 50, 60]),
            Validity::from_iter([true, true, true, true, true, true]),
        );
        check_sliced_cast_to_non_nullable(&values);
    }

    /// The only null sits at index 0, outside the `1..5` slice that gets cast to non-nullable.
    #[test]
    fn test_cast_sliced_pco_part_valid_to_nonnullable() {
        let values = PrimitiveArray::from_option_iter([
            None,
            Some(20u32),
            Some(30),
            Some(40),
            Some(50),
            Some(60),
        ]);
        check_sliced_cast_to_non_nullable(&values);
    }

    /// Runs the shared cast conformance suite against PCO arrays of several element types,
    /// including a single-element array.
    #[rstest]
    #[case::f32(PrimitiveArray::new(
        Buffer::copy_from(vec![1.23f32, 4.56, 7.89, 10.11, 12.13]),
        Validity::NonNullable,
    ))]
    #[case::f64(PrimitiveArray::new(
        Buffer::copy_from(vec![100.1f64, 200.2, 300.3, 400.4, 500.5]),
        Validity::NonNullable,
    ))]
    #[case::i32(PrimitiveArray::new(
        Buffer::copy_from(vec![100i32, 200, 300, 400, 500]),
        Validity::NonNullable,
    ))]
    #[case::u64(PrimitiveArray::new(
        Buffer::copy_from(vec![1000u64, 2000, 3000, 4000]),
        Validity::NonNullable,
    ))]
    #[case::single(PrimitiveArray::new(
        Buffer::copy_from(vec![42.42f64]),
        Validity::NonNullable,
    ))]
    fn test_cast_pco_conformance(#[case] values: PrimitiveArray) {
        let encoded = encode(&values);
        test_cast_conformance(encoded.as_ref());
    }
}