Skip to main content

vortex_fastlanes/rle/compute/
cast.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_array::Array;
5use vortex_array::ArrayRef;
6use vortex_array::builtins::ArrayBuiltins;
7use vortex_array::dtype::DType;
8use vortex_array::scalar_fn::fns::cast::CastReduce;
9use vortex_error::VortexResult;
10
11use crate::rle::RLEArray;
12use crate::rle::RLEVTable;
13
14impl CastReduce for RLEVTable {
15    fn cast(array: &RLEArray, dtype: &DType) -> VortexResult<Option<ArrayRef>> {
16        // Cast RLE values.
17        let casted_values = array.values().cast(dtype.clone())?;
18
19        // Cast RLE indices such that validity matches the target dtype.
20        let casted_indices = if array.indices().dtype().nullability() != dtype.nullability() {
21            array.indices().cast(DType::Primitive(
22                array.indices().dtype().as_ptype(),
23                dtype.nullability(),
24            ))?
25        } else {
26            array.indices().clone()
27        };
28
29        Ok(Some(unsafe {
30            RLEArray::new_unchecked(
31                casted_values,
32                casted_indices,
33                array.values_idx_offsets().clone(),
34                dtype.clone(),
35                array.offset(),
36                array.len(),
37            )
38            .into()
39        }))
40    }
41}
42
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_array::Array;
    use vortex_array::IntoArray;
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::builtins::ArrayBuiltins;
    use vortex_array::compute::conformance::cast::test_cast_conformance;
    use vortex_array::dtype::DType;
    use vortex_array::dtype::Nullability;
    use vortex_array::dtype::PType;
    use vortex_array::validity::Validity;
    use vortex_buffer::Buffer;

    use crate::rle::RLEArray;

    /// Casting an RLE-encoded `u8` array whose validity entries are all `true`
    /// to non-nullable `u16` succeeds and reports the requested dtype.
    #[test]
    fn try_cast_rle_success() {
        let target = DType::Primitive(PType::U16, Nullability::NonNullable);
        let source = PrimitiveArray::new(
            Buffer::from_iter([10u8, 20, 30, 40, 50]),
            Validity::from_iter([true; 5]),
        );
        let encoded = RLEArray::encode(&source).unwrap();

        let outcome = encoded.to_array().cast(target.clone());

        assert!(outcome.is_ok());
        assert_eq!(outcome.unwrap().dtype(), &target);
    }

    /// Casting an RLE-encoded array that has invalid entries to a non-nullable
    /// dtype must not survive canonicalization: the final `unwrap` is expected
    /// to panic.
    #[test]
    #[should_panic]
    fn try_cast_rle_fail() {
        let non_nullable_u8 = DType::Primitive(PType::U8, Nullability::NonNullable);
        let source = PrimitiveArray::new(
            Buffer::from_iter([10u8, 20, 30, 40, 50]),
            Validity::from_iter([true, false, true, true, false]),
        );
        let encoded = RLEArray::encode(&source).unwrap();

        // Canonicalize after the cast so the test's expected panic is still
        // hit when the cast call itself returns `Ok`.
        let canonicalized = encoded
            .to_array()
            .cast(non_nullable_u8)
            .and_then(|casted| casted.to_canonical())
            .map(|canonical| canonical.into_array());
        canonicalized.unwrap();
    }

    /// Runs the shared cast conformance suite over RLE encodings of unsigned
    /// primitive arrays, with and without invalid entries.
    #[rstest]
    #[case::u8(PrimitiveArray::new(
        Buffer::from_iter([0u8, 10, 20, 30, 40, 50]),
        Validity::NonNullable,
    ))]
    #[case::u8_nullable(PrimitiveArray::new(
        Buffer::from_iter([0u8, 10, 20, 30, 40]),
        Validity::from_iter([true, false, true, false, true]),
    ))]
    #[case::u16(PrimitiveArray::new(
        Buffer::from_iter([0u16, 100, 200, 300, 400, 500]),
        Validity::NonNullable,
    ))]
    #[case::u16_nullable(PrimitiveArray::new(
        Buffer::from_iter([0u16, 100, 200, 300, 400]),
        Validity::from_iter([false, true, false, true, true]),
    ))]
    #[case::u32(PrimitiveArray::new(
        Buffer::from_iter([0u32, 1000, 2000, 3000, 4000]),
        Validity::NonNullable,
    ))]
    #[case::u32_nullable(PrimitiveArray::new(
        Buffer::from_iter([0u32, 1000, 2000, 3000, 4000]),
        Validity::from_iter([true, true, false, false, true]),
    ))]
    #[case::u64(PrimitiveArray::new(
        Buffer::from_iter([0u64, 10000, 20000, 30000]),
        Validity::NonNullable,
    ))]
    #[case::u64_nullable(PrimitiveArray::new(
        Buffer::from_iter([0u64, 10000, 20000, 30000]),
        Validity::from_iter([false, false, true, true]),
    ))]
    fn test_cast_rle_conformance(#[case] primitive: PrimitiveArray) {
        let encoded = RLEArray::encode(&primitive).unwrap();
        test_cast_conformance(encoded.as_ref());
    }
}