// vortex_fastlanes/rle/compute/cast.rs

use vortex_array::Array;
use vortex_array::ArrayRef;
use vortex_array::builtins::ArrayBuiltins;
use vortex_array::dtype::DType;
use vortex_array::scalar_fn::fns::cast::CastReduce;
use vortex_error::VortexResult;

use crate::rle::RLEArray;
use crate::rle::RLEVTable;
14impl CastReduce for RLEVTable {
15 fn cast(array: &RLEArray, dtype: &DType) -> VortexResult<Option<ArrayRef>> {
16 let casted_values = array.values().cast(dtype.clone())?;
18
19 let casted_indices = if array.indices().dtype().nullability() != dtype.nullability() {
21 array.indices().cast(DType::Primitive(
22 array.indices().dtype().as_ptype(),
23 dtype.nullability(),
24 ))?
25 } else {
26 array.indices().clone()
27 };
28
29 Ok(Some(unsafe {
30 RLEArray::new_unchecked(
31 casted_values,
32 casted_indices,
33 array.values_idx_offsets().clone(),
34 dtype.clone(),
35 array.offset(),
36 array.len(),
37 )
38 .into()
39 }))
40 }
41}
42
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_array::Array;
    use vortex_array::IntoArray;
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::builtins::ArrayBuiltins;
    use vortex_array::compute::conformance::cast::test_cast_conformance;
    use vortex_array::dtype::DType;
    use vortex_array::dtype::Nullability;
    use vortex_array::dtype::PType;
    use vortex_array::validity::Validity;
    use vortex_buffer::Buffer;

    use crate::rle::RLEArray;

    /// Widening a fully-valid u8 RLE array to non-nullable u16 succeeds and yields the
    /// requested dtype.
    #[test]
    fn try_cast_rle_success() {
        let primitive = PrimitiveArray::new(
            Buffer::from_iter([10u8, 20, 30, 40, 50]),
            Validity::from_iter([true, true, true, true, true]),
        );
        let rle = RLEArray::encode(&primitive).unwrap();

        let target = DType::Primitive(PType::U16, Nullability::NonNullable);
        let casted = rle
            .to_array()
            .cast(target.clone())
            .expect("casting a fully-valid u8 RLE array to non-nullable u16 should succeed");
        assert_eq!(casted.dtype(), &target);
    }

    /// Casting an array that contains nulls to a non-nullable dtype must fail.
    #[test]
    #[should_panic]
    fn try_cast_rle_fail() {
        let primitive = PrimitiveArray::new(
            Buffer::from_iter([10u8, 20, 30, 40, 50]),
            Validity::from_iter([true, false, true, true, false]),
        );
        let rle = RLEArray::encode(&primitive).unwrap();
        // Canonicalize the result so the failure surfaces even if the cast is deferred
        // until the data is decoded.
        rle.to_array()
            .cast(DType::Primitive(PType::U8, Nullability::NonNullable))
            .and_then(|a| a.to_canonical().map(|c| c.into_array()))
            .unwrap();
    }

    /// Runs the shared cast conformance suite over RLE arrays of every unsigned width,
    /// both with and without nulls.
    #[rstest]
    #[case::u8(
        PrimitiveArray::new(
            Buffer::from_iter([0u8, 10, 20, 30, 40, 50]),
            Validity::NonNullable,
        )
    )]
    #[case::u8_nullable(
        PrimitiveArray::new(
            Buffer::from_iter([0u8, 10, 20, 30, 40]),
            Validity::from_iter([true, false, true, false, true]),
        )
    )]
    #[case::u16(
        PrimitiveArray::new(
            Buffer::from_iter([0u16, 100, 200, 300, 400, 500]),
            Validity::NonNullable,
        )
    )]
    #[case::u16_nullable(
        PrimitiveArray::new(
            Buffer::from_iter([0u16, 100, 200, 300, 400]),
            Validity::from_iter([false, true, false, true, true]),
        )
    )]
    #[case::u32(
        PrimitiveArray::new(
            Buffer::from_iter([0u32, 1000, 2000, 3000, 4000]),
            Validity::NonNullable,
        )
    )]
    #[case::u32_nullable(
        PrimitiveArray::new(
            Buffer::from_iter([0u32, 1000, 2000, 3000, 4000]),
            Validity::from_iter([true, true, false, false, true]),
        )
    )]
    #[case::u64(
        PrimitiveArray::new(
            Buffer::from_iter([0u64, 10000, 20000, 30000]),
            Validity::NonNullable,
        )
    )]
    #[case::u64_nullable(
        PrimitiveArray::new(
            Buffer::from_iter([0u64, 10000, 20000, 30000]),
            Validity::from_iter([false, false, true, true]),
        )
    )]
    fn test_cast_rle_conformance(#[case] primitive: PrimitiveArray) {
        let rle_array = RLEArray::encode(&primitive).unwrap();
        test_cast_conformance(rle_array.as_ref());
    }
}