vortex_fastlanes/bitpacking/compute/cast.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_array::ArrayRef;
5use vortex_array::IntoArray;
6use vortex_array::compute::CastKernel;
7use vortex_array::compute::CastKernelAdapter;
8use vortex_array::compute::cast;
9use vortex_array::patches::Patches;
10use vortex_array::register_kernel;
11use vortex_array::vtable::ValidityHelper;
12use vortex_dtype::DType;
13use vortex_error::VortexResult;
14
15use crate::bitpacking::BitPackedArray;
16use crate::bitpacking::BitPackedVTable;
17
18impl CastKernel for BitPackedVTable {
19    fn cast(&self, array: &BitPackedArray, dtype: &DType) -> VortexResult<Option<ArrayRef>> {
20        if array.dtype().eq_ignore_nullability(dtype) {
21            let new_validity = array
22                .validity()
23                .clone()
24                .cast_nullability(dtype.nullability(), array.len())?;
25            return Ok(Some(
26                BitPackedArray::try_new(
27                    array.packed().clone(),
28                    dtype.as_ptype(),
29                    new_validity,
30                    array
31                        .patches()
32                        .map(|patches| {
33                            let new_values = cast(patches.values(), dtype)?;
34                            VortexResult::Ok(Patches::new(
35                                patches.array_len(),
36                                patches.offset(),
37                                patches.indices().clone(),
38                                new_values,
39                                patches.chunk_offsets().clone(),
40                            ))
41                        })
42                        .transpose()?,
43                    array.bit_width(),
44                    array.len(),
45                    array.offset(),
46                )?
47                .into_array(),
48            ));
49        }
50
51        Ok(None)
52    }
53}
54
55register_kernel!(CastKernelAdapter(BitPackedVTable).lift());
56
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_array::IntoArray;
    use vortex_array::ToCanonical;
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::assert_arrays_eq;
    use vortex_array::compute::cast;
    use vortex_array::compute::conformance::cast::test_cast_conformance;
    use vortex_buffer::buffer;
    use vortex_dtype::DType;
    use vortex_dtype::Nullability;
    use vortex_dtype::PType;

    use crate::BitPackedArray;

    /// Widening the primitive type must keep every value.
    #[test]
    fn test_cast_bitpacked_u8_to_u32() {
        let packed =
            BitPackedArray::encode(buffer![10u8, 20, 30, 40, 50, 60].into_array().as_ref(), 6)
                .unwrap();

        let casted = cast(
            packed.as_ref(),
            &DType::Primitive(PType::U32, Nullability::NonNullable),
        )
        .unwrap();
        assert_eq!(
            casted.dtype(),
            &DType::Primitive(PType::U32, Nullability::NonNullable)
        );

        let decoded = casted.to_primitive();
        assert_arrays_eq!(
            decoded.as_ref(),
            PrimitiveArray::from_iter([10u32, 20, 30, 40, 50, 60])
        );
    }

    /// Widening a nullable array must keep both the values and the null positions,
    /// not just produce the right dtype.
    #[test]
    fn test_cast_bitpacked_nullable() {
        let values = PrimitiveArray::from_option_iter([Some(5u16), None, Some(10), Some(15), None]);
        let packed = BitPackedArray::encode(values.as_ref(), 4).unwrap();

        let casted = cast(
            packed.as_ref(),
            &DType::Primitive(PType::U32, Nullability::Nullable),
        )
        .unwrap();
        assert_eq!(
            casted.dtype(),
            &DType::Primitive(PType::U32, Nullability::Nullable)
        );

        let decoded = casted.to_primitive();
        assert_arrays_eq!(
            decoded.as_ref(),
            PrimitiveArray::from_option_iter([Some(5u32), None, Some(10), Some(15), None])
        );
    }

    /// Same `PType`, NonNullable -> Nullable: the path handled by the bit-packed kernel itself.
    #[test]
    fn test_cast_bitpacked_to_nullable_same_ptype() {
        let packed =
            BitPackedArray::encode(buffer![1u32, 2, 3, 7].into_array().as_ref(), 3).unwrap();

        let casted = cast(
            packed.as_ref(),
            &DType::Primitive(PType::U32, Nullability::Nullable),
        )
        .unwrap();
        assert_eq!(
            casted.dtype(),
            &DType::Primitive(PType::U32, Nullability::Nullable)
        );

        let decoded = casted.to_primitive();
        assert_arrays_eq!(
            decoded.as_ref(),
            PrimitiveArray::from_option_iter([Some(1u32), Some(2), Some(3), Some(7)])
        );
    }

    /// Patch values must be cast together with the packed data on the nullability-only path.
    #[test]
    fn test_cast_bitpacked_with_patches_to_nullable() {
        // 1000 does not fit in 2 bits, so encoding is expected to emit a patch for it.
        let packed =
            BitPackedArray::encode(buffer![1u32, 2, 3, 1000].into_array().as_ref(), 2).unwrap();
        assert!(packed.patches().is_some());

        let casted = cast(
            packed.as_ref(),
            &DType::Primitive(PType::U32, Nullability::Nullable),
        )
        .unwrap();
        assert_eq!(
            casted.dtype(),
            &DType::Primitive(PType::U32, Nullability::Nullable)
        );

        let decoded = casted.to_primitive();
        assert_arrays_eq!(
            decoded.as_ref(),
            PrimitiveArray::from_option_iter([Some(1u32), Some(2), Some(3), Some(1000)])
        );
    }

    /// Nullable -> NonNullable succeeds when no value is actually null.
    #[test]
    fn test_cast_bitpacked_nullable_to_non_nullable() {
        let values = PrimitiveArray::from_option_iter([Some(4u32), Some(5), Some(6)]);
        let packed = BitPackedArray::encode(values.as_ref(), 3).unwrap();

        let casted = cast(
            packed.as_ref(),
            &DType::Primitive(PType::U32, Nullability::NonNullable),
        )
        .unwrap();
        assert_eq!(
            casted.dtype(),
            &DType::Primitive(PType::U32, Nullability::NonNullable)
        );

        let decoded = casted.to_primitive();
        assert_arrays_eq!(decoded.as_ref(), PrimitiveArray::from_iter([4u32, 5, 6]));
    }

    /// Nullable -> NonNullable must fail when a null is present.
    #[test]
    fn test_cast_bitpacked_with_nulls_to_non_nullable_fails() {
        let values = PrimitiveArray::from_option_iter([Some(4u32), None, Some(6)]);
        let packed = BitPackedArray::encode(values.as_ref(), 3).unwrap();

        let result = cast(
            packed.as_ref(),
            &DType::Primitive(PType::U32, Nullability::NonNullable),
        );
        assert!(result.is_err());
    }

    #[rstest]
    #[case(BitPackedArray::encode(buffer![0u8, 10, 20, 30, 40, 50, 60, 63].into_array().as_ref(), 6).unwrap())]
    #[case(BitPackedArray::encode(buffer![0u16, 100, 200, 300, 400, 500].into_array().as_ref(), 9).unwrap())]
    #[case(BitPackedArray::encode(buffer![0u32, 1000, 2000, 3000, 4000].into_array().as_ref(), 12).unwrap())]
    #[case(BitPackedArray::encode(PrimitiveArray::from_option_iter([Some(1u32), None, Some(7), Some(15), None]).as_ref(), 4).unwrap())]
    // Values exceeding the bit width, so the encoded array carries patches.
    #[case(BitPackedArray::encode(buffer![1u32, 2, 3, 1000].into_array().as_ref(), 2).unwrap())]
    fn test_cast_bitpacked_conformance(#[case] array: BitPackedArray) {
        test_cast_conformance(array.as_ref());
    }
}