vortex_array/arrays/primitive/compute/
cast.rs1use vortex_buffer::Buffer;
5use vortex_buffer::BufferMut;
6use vortex_dtype::DType;
7use vortex_dtype::NativePType;
8use vortex_dtype::match_each_native_ptype;
9use vortex_error::VortexResult;
10use vortex_error::vortex_err;
11use vortex_mask::AllOr;
12use vortex_mask::Mask;
13
14use crate::ArrayRef;
15use crate::IntoArray;
16use crate::arrays::PrimitiveVTable;
17use crate::arrays::primitive::PrimitiveArray;
18use crate::compute::CastKernel;
19use crate::compute::CastKernelAdapter;
20use crate::register_kernel;
21use crate::vtable::ValidityHelper;
22
23impl CastKernel for PrimitiveVTable {
24 fn cast(&self, array: &PrimitiveArray, dtype: &DType) -> VortexResult<Option<ArrayRef>> {
25 let DType::Primitive(new_ptype, new_nullability) = dtype else {
26 return Ok(None);
27 };
28 let (new_ptype, new_nullability) = (*new_ptype, *new_nullability);
29
30 let new_validity = array
32 .validity()
33 .clone()
34 .cast_nullability(new_nullability, array.len())?;
35
36 if array.ptype() == new_ptype {
38 return Ok(Some(
39 PrimitiveArray::from_byte_buffer(
40 array.byte_buffer().clone(),
41 array.ptype(),
42 new_validity,
43 )
44 .into_array(),
45 ));
46 }
47
48 let mask = array.validity_mask();
49
50 Ok(Some(match_each_native_ptype!(new_ptype, |T| {
52 match_each_native_ptype!(array.ptype(), |F| {
53 PrimitiveArray::new(cast::<F, T>(array.as_slice(), mask)?, new_validity)
54 .into_array()
55 })
56 })))
57 }
58}
59
60register_kernel!(CastKernelAdapter(PrimitiveVTable).lift());
61
62fn cast<F: NativePType, T: NativePType>(array: &[F], mask: Mask) -> VortexResult<Buffer<T>> {
63 match mask.bit_buffer() {
64 AllOr::All => {
65 let mut buffer = BufferMut::with_capacity(array.len());
66 for item in array {
67 let item = T::from(*item).ok_or_else(
68 || vortex_err!(ComputeError: "Failed to cast {} to {:?}", item, T::PTYPE),
69 )?;
70 unsafe { buffer.push_unchecked(item) }
72 }
73 Ok(buffer.freeze())
74 }
75 AllOr::None => Ok(Buffer::zeroed(array.len())),
76 AllOr::Some(b) => {
77 let mut buffer = BufferMut::with_capacity(array.len());
79 for (item, valid) in array.iter().zip(b.iter()) {
80 if valid {
81 let item = T::from(*item).ok_or_else(
82 || vortex_err!(ComputeError: "Failed to cast {} to {:?}", item, T::PTYPE),
83 )?;
84 unsafe { buffer.push_unchecked(item) }
86 } else {
87 unsafe { buffer.push_unchecked(T::default()) }
89 }
90 }
91 Ok(buffer.freeze())
92 }
93 }
94}
95
#[cfg(test)]
mod test {
    use rstest::rstest;
    use vortex_buffer::BitBuffer;
    use vortex_buffer::buffer;
    use vortex_dtype::DType;
    use vortex_dtype::Nullability;
    use vortex_dtype::PType;
    use vortex_error::VortexError;
    use vortex_mask::Mask;

    use crate::IntoArray;
    use crate::arrays::PrimitiveArray;
    use crate::canonical::ToCanonical;
    use crate::compute::cast;
    use crate::compute::conformance::cast::test_cast_conformance;
    use crate::validity::Validity;
    use crate::vtable::ValidityHelper;

    /// Round-trips a `u32` array through `u8`/`u32` with both nullabilities, checking that
    /// values survive and that validity follows the target nullability.
    #[test]
    fn cast_u32_u8() {
        let arr = buffer![0u32, 10, 200].into_array();

        let p = cast(&arr, PType::U8.into()).unwrap().to_primitive();
        assert_eq!(p.as_slice::<u8>(), vec![0u8, 10, 200]);
        assert_eq!(p.validity(), &Validity::NonNullable);

        let p = cast(
            p.as_ref(),
            &DType::Primitive(PType::U8, Nullability::Nullable),
        )
        .unwrap()
        .to_primitive();
        assert_eq!(p.as_slice::<u8>(), vec![0u8, 10, 200]);
        assert_eq!(p.validity(), &Validity::AllValid);

        let p = cast(
            p.as_ref(),
            &DType::Primitive(PType::U8, Nullability::NonNullable),
        )
        .unwrap()
        .to_primitive();
        assert_eq!(p.as_slice::<u8>(), vec![0u8, 10, 200]);
        assert_eq!(p.validity(), &Validity::NonNullable);

        let p = cast(
            p.as_ref(),
            &DType::Primitive(PType::U32, Nullability::Nullable),
        )
        .unwrap()
        .to_primitive();
        assert_eq!(p.as_slice::<u32>(), vec![0u32, 10, 200]);
        assert_eq!(p.validity(), &Validity::AllValid);

        let p = cast(
            p.as_ref(),
            &DType::Primitive(PType::U8, Nullability::NonNullable),
        )
        .unwrap()
        .to_primitive();
        assert_eq!(p.as_slice::<u8>(), vec![0u8, 10, 200]);
        assert_eq!(p.validity(), &Validity::NonNullable);
    }

    #[test]
    fn cast_u32_f32() {
        let arr = buffer![0u32, 10, 200].into_array();
        let f32_arr = cast(&arr, PType::F32.into()).unwrap().to_primitive();
        assert_eq!(f32_arr.as_slice::<f32>(), vec![0.0f32, 10., 200.]);
    }

    /// A valid value that does not fit in the target type must produce a `ComputeError`.
    #[test]
    fn cast_i32_u32() {
        let arr = buffer![-1i32].into_array();
        let err = cast(&arr, PType::U32.into()).unwrap_err();
        let VortexError::ComputeError(s, _) = err else {
            panic!("expected a ComputeError when casting an out-of-range value");
        };
        assert_eq!(s.to_string(), "Failed to cast -1 to U32");
    }

    #[test]
    fn cast_array_with_nulls_to_nonnullable() {
        let arr = PrimitiveArray::from_option_iter([Some(-1i32), None, Some(10)]);
        let err = cast(arr.as_ref(), PType::I32.into()).unwrap_err();
        let VortexError::InvalidArgument(s, _) = err else {
            panic!("expected an InvalidArgument error when casting nulls to non-nullable");
        };
        assert_eq!(
            s.to_string(),
            "Cannot cast array with invalid values to non-nullable type."
        );
    }

    /// Out-of-range values hidden under nulls must not make the cast fail.
    #[test]
    fn cast_with_invalid_nulls() {
        let arr = PrimitiveArray::new(
            buffer![-1i32, 0, 10],
            Validity::from_iter([false, true, true]),
        );
        let p = cast(
            arr.as_ref(),
            &DType::Primitive(PType::U32, Nullability::Nullable),
        )
        .unwrap()
        .to_primitive();
        assert_eq!(p.as_slice::<u32>(), vec![0, 0, 10]);
        assert_eq!(
            p.validity_mask(),
            Mask::from(BitBuffer::from(vec![false, true, true]))
        );
    }

    /// An entirely-null array takes the all-invalid path and yields zeroed values.
    #[test]
    fn cast_all_null_array() {
        let arr = PrimitiveArray::from_option_iter([None::<i32>, None, None]);
        let p = cast(
            arr.as_ref(),
            &DType::Primitive(PType::U32, Nullability::Nullable),
        )
        .unwrap()
        .to_primitive();
        assert_eq!(p.as_slice::<u32>(), vec![0u32, 0, 0]);
        assert_eq!(p.validity_mask().true_count(), 0);
    }

    #[rstest]
    #[case(buffer![0u8, 1, 2, 3, 255].into_array())]
    #[case(buffer![0u16, 100, 1000, 65535].into_array())]
    #[case(buffer![0u32, 100, 1000, 1000000].into_array())]
    #[case(buffer![0u64, 100, 1000, 1000000000].into_array())]
    #[case(buffer![-128i8, -1, 0, 1, 127].into_array())]
    #[case(buffer![-1000i16, -1, 0, 1, 1000].into_array())]
    #[case(buffer![-1000000i32, -1, 0, 1, 1000000].into_array())]
    #[case(buffer![-1000000000i64, -1, 0, 1, 1000000000].into_array())]
    #[case(buffer![0.0f32, 1.5, -2.5, 100.0, 1e6].into_array())]
    #[case(buffer![0.0f64, 1.5, -2.5, 100.0, 1e12].into_array())]
    #[case(PrimitiveArray::from_option_iter([Some(1u8), None, Some(255), Some(0), None]).into_array())]
    #[case(PrimitiveArray::from_option_iter([Some(1i32), None, Some(-100), Some(0), None]).into_array())]
    #[case(buffer![42u32].into_array())]
    fn test_cast_primitive_conformance(#[case] array: crate::ArrayRef) {
        test_cast_conformance(array.as_ref());
    }
}