vortex_pco/compute/
cast.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_array::ArrayRef;
5use vortex_array::IntoArray;
6use vortex_array::compute::CastKernel;
7use vortex_array::compute::CastKernelAdapter;
8use vortex_array::register_kernel;
9use vortex_dtype::DType;
10use vortex_error::VortexResult;
11
12use crate::PcoArray;
13use crate::PcoVTable;
14
15impl CastKernel for PcoVTable {
16    fn cast(&self, array: &PcoArray, dtype: &DType) -> VortexResult<Option<ArrayRef>> {
17        if !dtype.is_nullable() || !array.all_valid() {
18            // TODO(joe): fixme
19            // We cannot cast to non-nullable since the validity containing nulls is used to decode
20            // the PCO array, this would require rewriting tables.
21            return Ok(None);
22        }
23        // PCO (Pcodec) is a compression encoding that stores data in a compressed format.
24        // It can efficiently handle nullability changes without decompression, but type changes
25        // require decompression since the compression algorithm is type-specific.
26        // PCO supports: F16, F32, F64, I16, I32, I64, U16, U32, U64
27        if array.dtype().eq_ignore_nullability(dtype) {
28            // Create a new validity with the target nullability
29            let new_validity = array
30                .unsliced_validity
31                .clone()
32                .cast_nullability(dtype.nullability(), array.len())?;
33
34            return Ok(Some(
35                PcoArray::new(
36                    array.chunk_metas.clone(),
37                    array.pages.clone(),
38                    dtype.clone(),
39                    array.metadata.clone(),
40                    array.unsliced_n_rows(),
41                    new_validity,
42                )
43                ._slice(array.slice_start(), array.slice_stop())
44                .into_array(),
45            ));
46        }
47
48        // For other casts (e.g., numeric type changes), decode to canonical and let PrimitiveArray handle it
49        Ok(None)
50    }
51}
52
53register_kernel!(CastKernelAdapter(PcoVTable).lift());
54
55#[cfg(test)]
56mod tests {
57    use rstest::rstest;
58    use vortex_array::ToCanonical;
59    use vortex_array::arrays::PrimitiveArray;
60    use vortex_array::assert_arrays_eq;
61    use vortex_array::compute::cast;
62    use vortex_array::compute::conformance::cast::test_cast_conformance;
63    use vortex_array::validity::Validity;
64    use vortex_buffer::Buffer;
65    use vortex_dtype::DType;
66    use vortex_dtype::Nullability;
67    use vortex_dtype::PType;
68
69    use crate::PcoArray;
70
71    #[test]
72    fn test_cast_pco_f32_to_f64() {
73        let values = PrimitiveArray::new(
74            Buffer::copy_from(vec![1.0f32, 2.0, 3.0, 4.0, 5.0]),
75            Validity::NonNullable,
76        );
77        let pco = PcoArray::from_primitive(&values, 0, 128).unwrap();
78
79        let casted = cast(
80            pco.as_ref(),
81            &DType::Primitive(PType::F64, Nullability::NonNullable),
82        )
83        .unwrap();
84        assert_eq!(
85            casted.dtype(),
86            &DType::Primitive(PType::F64, Nullability::NonNullable)
87        );
88
89        let decoded = casted.to_primitive();
90        let f64_values = decoded.as_slice::<f64>();
91        assert_eq!(f64_values.len(), 5);
92        assert!((f64_values[0] - 1.0).abs() < f64::EPSILON);
93    }
94
95    #[test]
96    fn test_cast_pco_nullability_change() {
97        // Test casting from NonNullable to Nullable
98        let values = PrimitiveArray::new(
99            Buffer::copy_from(vec![10u32, 20, 30, 40]),
100            Validity::NonNullable,
101        );
102        let pco = PcoArray::from_primitive(&values, 0, 128).unwrap();
103
104        let casted = cast(
105            pco.as_ref(),
106            &DType::Primitive(PType::U32, Nullability::Nullable),
107        )
108        .unwrap();
109        assert_eq!(
110            casted.dtype(),
111            &DType::Primitive(PType::U32, Nullability::Nullable)
112        );
113    }
114
115    #[test]
116    fn test_cast_sliced_pco_nullable_to_nonnullable() {
117        let values = PrimitiveArray::new(
118            Buffer::copy_from(vec![10u32, 20, 30, 40, 50, 60]),
119            Validity::from_iter([true, true, true, true, true, true]),
120        );
121        let pco = PcoArray::from_primitive(&values, 0, 128).unwrap();
122        let sliced = pco.slice(1..5);
123        let casted = cast(
124            sliced.as_ref(),
125            &DType::Primitive(PType::U32, Nullability::NonNullable),
126        )
127        .unwrap();
128        assert_eq!(
129            casted.dtype(),
130            &DType::Primitive(PType::U32, Nullability::NonNullable)
131        );
132        // Verify the values are correct
133        let decoded = casted.to_primitive();
134        let u32_values = decoded.as_slice::<u32>();
135        assert_eq!(u32_values, &[20, 30, 40, 50]);
136    }
137
138    #[test]
139    fn test_cast_sliced_pco_part_valid_to_nonnullable() {
140        let values = PrimitiveArray::from_option_iter([
141            None,
142            Some(20u32),
143            Some(30),
144            Some(40),
145            Some(50),
146            Some(60),
147        ]);
148        let pco = PcoArray::from_primitive(&values, 0, 128).unwrap();
149        let sliced = pco.slice(1..5);
150        let casted = cast(
151            sliced.as_ref(),
152            &DType::Primitive(PType::U32, Nullability::NonNullable),
153        )
154        .unwrap();
155        assert_eq!(
156            casted.dtype(),
157            &DType::Primitive(PType::U32, Nullability::NonNullable)
158        );
159        let decoded = casted.to_primitive();
160        let expected = PrimitiveArray::from_iter([20u32, 30, 40, 50]);
161        assert_arrays_eq!(decoded, expected);
162    }
163
164    #[rstest]
165    #[case::f32(PrimitiveArray::new(
166        Buffer::copy_from(vec![1.23f32, 4.56, 7.89, 10.11, 12.13]),
167        Validity::NonNullable,
168    ))]
169    #[case::f64(PrimitiveArray::new(
170        Buffer::copy_from(vec![100.1f64, 200.2, 300.3, 400.4, 500.5]),
171        Validity::NonNullable,
172    ))]
173    #[case::i32(PrimitiveArray::new(
174        Buffer::copy_from(vec![100i32, 200, 300, 400, 500]),
175        Validity::NonNullable,
176    ))]
177    #[case::u64(PrimitiveArray::new(
178        Buffer::copy_from(vec![1000u64, 2000, 3000, 4000]),
179        Validity::NonNullable,
180    ))]
181    #[case::single(PrimitiveArray::new(
182        Buffer::copy_from(vec![42.42f64]),
183        Validity::NonNullable,
184    ))]
185    fn test_cast_pco_conformance(#[case] values: PrimitiveArray) {
186        let pco = PcoArray::from_primitive(&values, 0, 128).unwrap();
187        test_cast_conformance(pco.as_ref());
188    }
189}