vortex_fastlanes/rle/compute/cast.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_array::ArrayRef;
5use vortex_array::compute::CastKernel;
6use vortex_array::compute::CastKernelAdapter;
7use vortex_array::compute::cast;
8use vortex_array::register_kernel;
9use vortex_dtype::DType;
10use vortex_error::VortexResult;
11
12use crate::rle::RLEArray;
13use crate::rle::RLEVTable;
14
15impl CastKernel for RLEVTable {
16    fn cast(&self, array: &RLEArray, dtype: &DType) -> VortexResult<Option<ArrayRef>> {
17        // Cast RLE values.
18        let casted_values = cast(array.values(), dtype)?;
19
20        // Cast RLE indices such that validity matches the target dtype.
21        let casted_indices = if array.indices().dtype().nullability() != dtype.nullability() {
22            cast(
23                array.indices(),
24                &DType::Primitive(array.indices().dtype().as_ptype(), dtype.nullability()),
25            )?
26        } else {
27            array.indices().clone()
28        };
29
30        Ok(Some(unsafe {
31            RLEArray::new_unchecked(
32                casted_values,
33                casted_indices,
34                array.values_idx_offsets().clone(),
35                dtype.clone(),
36                array.offset(),
37                array.len(),
38            )
39            .into()
40        }))
41    }
42}
43
44register_kernel!(CastKernelAdapter(RLEVTable).lift());
45
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::compute::cast;
    use vortex_array::compute::conformance::cast::test_cast_conformance;
    use vortex_array::validity::Validity;
    use vortex_buffer::Buffer;
    use vortex_dtype::DType;
    use vortex_dtype::Nullability;
    use vortex_dtype::PType;

    use crate::rle::RLEArray;

    /// Widening an all-valid RLE array to a non-nullable type succeeds with the target dtype.
    #[test]
    fn try_cast_rle_success() {
        let primitive = PrimitiveArray::new(
            Buffer::from_iter([10u8, 20, 30, 40, 50]),
            Validity::from_iter([true, true, true, true, true]),
        );
        let rle = RLEArray::encode(&primitive).unwrap();

        let target = DType::Primitive(PType::U16, Nullability::NonNullable);
        let casted = cast(rle.as_ref(), &target).unwrap();
        assert_eq!(casted.dtype(), &target);
    }

    /// Casting an RLE array that contains nulls to a non-nullable type must return an error
    /// (asserted directly rather than via `should_panic`, which would accept any panic).
    #[test]
    fn try_cast_rle_fail() {
        let primitive = PrimitiveArray::new(
            Buffer::from_iter([10u8, 20, 30, 40, 50]),
            Validity::from_iter([true, false, true, true, false]),
        );
        let rle = RLEArray::encode(&primitive).unwrap();

        let result = cast(
            rle.as_ref(),
            &DType::Primitive(PType::U8, Nullability::NonNullable),
        );
        assert!(
            result.is_err(),
            "casting an array with nulls to a non-nullable type should fail"
        );
    }

    #[rstest]
    #[case::u8(
        PrimitiveArray::new(
            Buffer::from_iter([0u8, 10, 20, 30, 40, 50]),
            Validity::NonNullable,
        )
    )]
    #[case::u8_nullable(
        PrimitiveArray::new(
            Buffer::from_iter([0u8, 10, 20, 30, 40]),
            Validity::from_iter([true, false, true, false, true]),
        )
    )]
    #[case::u16(
        PrimitiveArray::new(
            Buffer::from_iter([0u16, 100, 200, 300, 400, 500]),
            Validity::NonNullable,
        )
    )]
    #[case::u16_nullable(
        PrimitiveArray::new(
            Buffer::from_iter([0u16, 100, 200, 300, 400]),
            Validity::from_iter([false, true, false, true, true]),
        )
    )]
    #[case::u32(
        PrimitiveArray::new(
            Buffer::from_iter([0u32, 1000, 2000, 3000, 4000]),
            Validity::NonNullable,
        )
    )]
    #[case::u32_nullable(
        PrimitiveArray::new(
            Buffer::from_iter([0u32, 1000, 2000, 3000, 4000]),
            Validity::from_iter([true, true, false, false, true]),
        )
    )]
    // The cases above contain only distinct values (runs of length one); these exercise
    // actual runs so that indices are shared between several positions.
    #[case::u32_runs(
        PrimitiveArray::new(
            Buffer::from_iter([7u32, 7, 7, 9, 9, 1, 1, 1]),
            Validity::NonNullable,
        )
    )]
    #[case::u32_runs_nullable(
        PrimitiveArray::new(
            Buffer::from_iter([7u32, 7, 7, 9, 9, 1, 1, 1]),
            Validity::from_iter([true, true, false, true, true, false, true, true]),
        )
    )]
    #[case::u64(
        PrimitiveArray::new(
            Buffer::from_iter([0u64, 10000, 20000, 30000]),
            Validity::NonNullable,
        )
    )]
    #[case::u64_nullable(
        PrimitiveArray::new(
            Buffer::from_iter([0u64, 10000, 20000, 30000]),
            Validity::from_iter([false, false, true, true]),
        )
    )]
    fn test_cast_rle_conformance(#[case] primitive: PrimitiveArray) {
        let rle_array = RLEArray::encode(&primitive).unwrap();
        test_cast_conformance(rle_array.as_ref());
    }
}