Skip to main content

vortex_fastlanes/bitpacking/compute/
cast.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_array::ArrayRef;
5use vortex_array::IntoArray;
6use vortex_array::builtins::ArrayBuiltins;
7use vortex_array::dtype::DType;
8use vortex_array::patches::Patches;
9use vortex_array::scalar_fn::fns::cast::CastReduce;
10use vortex_array::vtable::ValidityHelper;
11use vortex_error::VortexResult;
12
13use crate::bitpacking::BitPackedArray;
14use crate::bitpacking::BitPackedVTable;
15
16impl CastReduce for BitPackedVTable {
17    fn cast(array: &BitPackedArray, dtype: &DType) -> VortexResult<Option<ArrayRef>> {
18        if array.dtype().eq_ignore_nullability(dtype) {
19            let new_validity = array
20                .validity()
21                .clone()
22                .cast_nullability(dtype.nullability(), array.len())?;
23            return Ok(Some(
24                BitPackedArray::try_new(
25                    array.packed().clone(),
26                    dtype.as_ptype(),
27                    new_validity,
28                    array
29                        .patches()
30                        .map(|patches| {
31                            let new_values = patches.values().cast(dtype.clone())?;
32                            Patches::new(
33                                patches.array_len(),
34                                patches.offset(),
35                                patches.indices().clone(),
36                                new_values,
37                                patches.chunk_offsets().clone(),
38                            )
39                        })
40                        .transpose()?,
41                    array.bit_width(),
42                    array.len(),
43                    array.offset(),
44                )?
45                .into_array(),
46            ));
47        }
48
49        Ok(None)
50    }
51}
52
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_array::IntoArray;
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::assert_arrays_eq;
    use vortex_array::builtins::ArrayBuiltins;
    use vortex_array::compute::conformance::cast::test_cast_conformance;
    use vortex_array::dtype::DType;
    use vortex_array::dtype::Nullability;
    use vortex_array::dtype::PType;
    use vortex_buffer::buffer;

    use crate::BitPackedArray;

    /// Widening the primitive type (u8 -> u32) must preserve every value.
    #[test]
    fn test_cast_bitpacked_u8_to_u32() {
        let packed =
            BitPackedArray::encode(buffer![10u8, 20, 30, 40, 50, 60].into_array().as_ref(), 6)
                .unwrap();

        let casted = packed
            .to_array()
            .cast(DType::Primitive(PType::U32, Nullability::NonNullable))
            .unwrap();
        assert_eq!(
            casted.dtype(),
            &DType::Primitive(PType::U32, Nullability::NonNullable)
        );

        assert_arrays_eq!(
            casted.as_ref(),
            PrimitiveArray::from_option_iter([10u32, 20, 30, 40, 50, 60].map(Some))
                .as_ref()
                .to_array()
                .cast(DType::Primitive(PType::U32, Nullability::NonNullable))
                .unwrap()
                .as_ref()
        );
    }

    /// Casting a nullable bit-packed array must keep both the values and the null positions.
    #[test]
    fn test_cast_bitpacked_nullable() {
        let values = PrimitiveArray::from_option_iter([Some(5u16), None, Some(10), Some(15), None]);
        let packed = BitPackedArray::encode(values.as_ref(), 4).unwrap();

        let casted = packed
            .to_array()
            .cast(DType::Primitive(PType::U32, Nullability::Nullable))
            .unwrap();
        assert_eq!(
            casted.dtype(),
            &DType::Primitive(PType::U32, Nullability::Nullable)
        );

        // Checking the dtype alone would not catch lost values or shifted nulls.
        assert_arrays_eq!(
            casted.as_ref(),
            PrimitiveArray::from_option_iter([Some(5u32), None, Some(10), Some(15), None])
        );
    }

    /// A nullability-only cast keeps the ptype, so it goes through `CastReduce`, which
    /// rebuilds the patches. Out-of-range (patched) values must survive that rebuild.
    #[test]
    fn test_cast_bitpacked_with_patches_to_nullable() {
        // 1000 does not fit in 2 bits, so encoding stores it as a patch.
        let packed =
            BitPackedArray::encode(buffer![1u32, 2, 3, 1000, 2].into_array().as_ref(), 2)
                .unwrap();
        assert!(packed.patches().is_some());

        let casted = packed
            .to_array()
            .cast(DType::Primitive(PType::U32, Nullability::Nullable))
            .unwrap();
        assert_eq!(
            casted.dtype(),
            &DType::Primitive(PType::U32, Nullability::Nullable)
        );

        assert_arrays_eq!(
            casted.as_ref(),
            PrimitiveArray::from_option_iter([Some(1u32), Some(2), Some(3), Some(1000), Some(2)])
        );
    }

    #[rstest]
    #[case(BitPackedArray::encode(buffer![0u8, 10, 20, 30, 40, 50, 60, 63].into_array().as_ref(), 6).unwrap())]
    #[case(BitPackedArray::encode(buffer![0u16, 100, 200, 300, 400, 500].into_array().as_ref(), 9).unwrap())]
    #[case(BitPackedArray::encode(buffer![0u32, 1000, 2000, 3000, 4000].into_array().as_ref(), 12).unwrap())]
    #[case(BitPackedArray::encode(PrimitiveArray::from_option_iter([Some(1u32), None, Some(7), Some(15), None]).as_ref(), 4).unwrap())]
    fn test_cast_bitpacked_conformance(#[case] array: BitPackedArray) {
        test_cast_conformance(array.as_ref());
    }
}