vortex_pco/compute/
cast.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_array::compute::{CastKernel, CastKernelAdapter};
5use vortex_array::{ArrayRef, IntoArray, register_kernel};
6use vortex_dtype::DType;
7use vortex_error::VortexResult;
8
9use crate::{PcoArray, PcoVTable};
10
11impl CastKernel for PcoVTable {
12    fn cast(&self, array: &PcoArray, dtype: &DType) -> VortexResult<Option<ArrayRef>> {
13        if !dtype.is_nullable() || !array.all_valid() {
14            // TODO(joe): fixme
15            // We cannot cast to non-nullable since the validity containing nulls is used to decode
16            // the PCO array, this would require rewriting tables.
17            return Ok(None);
18        }
19        // PCO (Pcodec) is a compression encoding that stores data in a compressed format.
20        // It can efficiently handle nullability changes without decompression, but type changes
21        // require decompression since the compression algorithm is type-specific.
22        // PCO supports: F16, F32, F64, I16, I32, I64, U16, U32, U64
23        if array.dtype().eq_ignore_nullability(dtype) {
24            // Create a new validity with the target nullability
25            let new_validity = array
26                .unsliced_validity
27                .clone()
28                .cast_nullability(dtype.nullability(), array.len())?;
29
30            return Ok(Some(
31                PcoArray::new(
32                    array.chunk_metas.clone(),
33                    array.pages.clone(),
34                    dtype.clone(),
35                    array.metadata.clone(),
36                    array.unsliced_n_rows(),
37                    new_validity,
38                )
39                ._slice(array.slice_start(), array.slice_stop())
40                .into_array(),
41            ));
42        }
43
44        // For other casts (e.g., numeric type changes), decode to canonical and let PrimitiveArray handle it
45        Ok(None)
46    }
47}
48
49register_kernel!(CastKernelAdapter(PcoVTable).lift());
50
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::compute::cast;
    use vortex_array::compute::conformance::cast::test_cast_conformance;
    use vortex_array::validity::Validity;
    use vortex_array::{ToCanonical, assert_arrays_eq};
    use vortex_buffer::Buffer;
    use vortex_dtype::{DType, Nullability, PType};

    use crate::PcoArray;

    /// Casting f32 -> f64 changes the numeric type, which the Pco kernel declines; the cast
    /// must still produce an f64 array with the expected length and first value.
    #[test]
    fn test_cast_pco_f32_to_f64() {
        let target = DType::Primitive(PType::F64, Nullability::NonNullable);
        let source = PrimitiveArray::new(
            Buffer::copy_from(vec![1.0f32, 2.0, 3.0, 4.0, 5.0]),
            Validity::NonNullable,
        );
        let encoded = PcoArray::from_primitive(&source, 0, 128).unwrap();

        let result = cast(encoded.as_ref(), &target).unwrap();
        assert_eq!(result.dtype(), &target);

        let canonical = result.to_primitive();
        let widened = canonical.as_slice::<f64>();
        assert_eq!(widened.len(), 5);
        assert!((widened[0] - 1.0).abs() < f64::EPSILON);
    }

    /// Casting NonNullable -> Nullable with the same primitive type must yield the nullable
    /// dtype.
    #[test]
    fn test_cast_pco_nullability_change() {
        let target = DType::Primitive(PType::U32, Nullability::Nullable);
        let source = PrimitiveArray::new(
            Buffer::copy_from(vec![10u32, 20, 30, 40]),
            Validity::NonNullable,
        );
        let encoded = PcoArray::from_primitive(&source, 0, 128).unwrap();

        let result = cast(encoded.as_ref(), &target).unwrap();
        assert_eq!(result.dtype(), &target);
    }

    /// A sliced, nullable-but-fully-valid array cast to non-nullable must keep exactly the
    /// values inside the slice window.
    #[test]
    fn test_cast_sliced_pco_nullable_to_nonnullable() {
        let target = DType::Primitive(PType::U32, Nullability::NonNullable);
        let source = PrimitiveArray::new(
            Buffer::copy_from(vec![10u32, 20, 30, 40, 50, 60]),
            Validity::from_iter([true; 6]),
        );
        let encoded = PcoArray::from_primitive(&source, 0, 128).unwrap();
        let window = encoded.slice(1..5);

        let result = cast(window.as_ref(), &target).unwrap();
        assert_eq!(result.dtype(), &target);

        // Verify the values are correct
        let canonical = result.to_primitive();
        let observed = canonical.as_slice::<u32>();
        assert_eq!(observed, &[20, 30, 40, 50]);
    }

    /// The only null sits outside the slice window, so the sliced array can still be cast to
    /// non-nullable and must decode to the in-window values.
    #[test]
    fn test_cast_sliced_pco_part_valid_to_nonnullable() {
        let target = DType::Primitive(PType::U32, Nullability::NonNullable);
        let source = PrimitiveArray::from_option_iter([
            None,
            Some(20u32),
            Some(30),
            Some(40),
            Some(50),
            Some(60),
        ]);
        let encoded = PcoArray::from_primitive(&source, 0, 128).unwrap();
        let window = encoded.slice(1..5);

        let result = cast(window.as_ref(), &target).unwrap();
        assert_eq!(result.dtype(), &target);

        let canonical = result.to_primitive();
        let want = PrimitiveArray::from_iter([20u32, 30, 40, 50]);
        assert_arrays_eq!(canonical, want);
    }

    /// Runs the shared cast conformance suite against Pco arrays built from a variety of
    /// primitive inputs.
    #[rstest]
    #[case::f32(PrimitiveArray::new(
        Buffer::copy_from(vec![1.23f32, 4.56, 7.89, 10.11, 12.13]),
        Validity::NonNullable,
    ))]
    #[case::f64(PrimitiveArray::new(
        Buffer::copy_from(vec![100.1f64, 200.2, 300.3, 400.4, 500.5]),
        Validity::NonNullable,
    ))]
    #[case::i32(PrimitiveArray::new(
        Buffer::copy_from(vec![100i32, 200, 300, 400, 500]),
        Validity::NonNullable,
    ))]
    #[case::u64(PrimitiveArray::new(
        Buffer::copy_from(vec![1000u64, 2000, 3000, 4000]),
        Validity::NonNullable,
    ))]
    #[case::single(PrimitiveArray::new(
        Buffer::copy_from(vec![42.42f64]),
        Validity::NonNullable,
    ))]
    fn test_cast_pco_conformance(#[case] values: PrimitiveArray) {
        let encoded = PcoArray::from_primitive(&values, 0, 128).unwrap();
        test_cast_conformance(encoded.as_ref());
    }
}