vortex_array/arrays/primitive/compute/
cast.rs1use vortex_buffer::Buffer;
5use vortex_buffer::BufferMut;
6use vortex_dtype::DType;
7use vortex_dtype::NativePType;
8use vortex_dtype::match_each_native_ptype;
9use vortex_error::VortexResult;
10use vortex_error::vortex_err;
11use vortex_mask::AllOr;
12use vortex_mask::Mask;
13
14use crate::ArrayRef;
15use crate::ExecutionCtx;
16use crate::IntoArray;
17use crate::arrays::PrimitiveVTable;
18use crate::arrays::primitive::PrimitiveArray;
19use crate::compute::CastKernel;
20use crate::vtable::ValidityHelper;
21
22impl CastKernel for PrimitiveVTable {
23 fn cast(
24 array: &PrimitiveArray,
25 dtype: &DType,
26 _ctx: &mut ExecutionCtx,
27 ) -> VortexResult<Option<ArrayRef>> {
28 let DType::Primitive(new_ptype, new_nullability) = dtype else {
29 return Ok(None);
30 };
31 let (new_ptype, new_nullability) = (*new_ptype, *new_nullability);
32
33 let new_validity = array
35 .validity()
36 .clone()
37 .cast_nullability(new_nullability, array.len())?;
38
39 if array.ptype() == new_ptype {
41 return Ok(Some(unsafe {
43 PrimitiveArray::new_unchecked_from_handle(
44 array.buffer_handle().clone(),
45 array.ptype(),
46 new_validity,
47 )
48 .into_array()
49 }));
50 }
51
52 let mask = array.validity_mask()?;
53
54 Ok(Some(match_each_native_ptype!(new_ptype, |T| {
56 match_each_native_ptype!(array.ptype(), |F| {
57 PrimitiveArray::new(cast::<F, T>(array.as_slice(), mask)?, new_validity)
58 .into_array()
59 })
60 })))
61 }
62}
63
64fn cast<F: NativePType, T: NativePType>(array: &[F], mask: Mask) -> VortexResult<Buffer<T>> {
65 match mask.bit_buffer() {
66 AllOr::All => {
67 let mut buffer = BufferMut::with_capacity(array.len());
68 for item in array {
69 let item = T::from(*item).ok_or_else(
70 || vortex_err!(Compute: "Failed to cast {} to {:?}", item, T::PTYPE),
71 )?;
72 unsafe { buffer.push_unchecked(item) }
74 }
75 Ok(buffer.freeze())
76 }
77 AllOr::None => Ok(Buffer::zeroed(array.len())),
78 AllOr::Some(b) => {
79 let mut buffer = BufferMut::with_capacity(array.len());
81 for (item, valid) in array.iter().zip(b.iter()) {
82 if valid {
83 let item = T::from(*item).ok_or_else(
84 || vortex_err!(Compute: "Failed to cast {} to {:?}", item, T::PTYPE),
85 )?;
86 unsafe { buffer.push_unchecked(item) }
88 } else {
89 unsafe { buffer.push_unchecked(T::default()) }
91 }
92 }
93 Ok(buffer.freeze())
94 }
95 }
96}
97
#[cfg(test)]
mod test {
    use rstest::rstest;
    use vortex_buffer::BitBuffer;
    use vortex_buffer::buffer;
    use vortex_dtype::DType;
    use vortex_dtype::Nullability;
    use vortex_dtype::PType;
    use vortex_error::VortexError;
    use vortex_mask::Mask;

    use crate::IntoArray;
    use crate::arrays::PrimitiveArray;
    use crate::assert_arrays_eq;
    use crate::builtins::ArrayBuiltins;
    use crate::canonical::ToCanonical;
    use crate::compute::conformance::cast::test_cast_conformance;
    use crate::validity::Validity;
    use crate::vtable::ValidityHelper;

    // Walks a u32 array through u8 and back while toggling nullability, checking both the
    // values and the resulting validity after every step.
    #[test]
    fn cast_u32_u8() {
        let source = buffer![0u32, 10, 200].into_array();

        // u32 -> u8.
        let narrowed = source.cast(PType::U8.into()).unwrap().to_primitive();
        assert_arrays_eq!(narrowed, PrimitiveArray::from_iter([0u8, 10, 200]));
        assert_eq!(narrowed.validity(), &Validity::NonNullable);

        // Same ptype, non-nullable -> nullable.
        let nullable_u8 = narrowed
            .to_array()
            .cast(DType::Primitive(PType::U8, Nullability::Nullable))
            .unwrap()
            .to_primitive();
        assert_arrays_eq!(
            nullable_u8,
            PrimitiveArray::new(buffer![0u8, 10, 200], Validity::AllValid)
        );
        assert_eq!(nullable_u8.validity(), &Validity::AllValid);

        // Same ptype, nullable -> non-nullable.
        let non_nullable_u8 = nullable_u8
            .to_array()
            .cast(DType::Primitive(PType::U8, Nullability::NonNullable))
            .unwrap()
            .to_primitive();
        assert_arrays_eq!(non_nullable_u8, PrimitiveArray::from_iter([0u8, 10, 200]));
        assert_eq!(non_nullable_u8.validity(), &Validity::NonNullable);

        // Ptype and nullability change together: u8 -> nullable u32.
        let nullable_u32 = non_nullable_u8
            .to_array()
            .cast(DType::Primitive(PType::U32, Nullability::Nullable))
            .unwrap()
            .to_primitive();
        assert_arrays_eq!(
            nullable_u32,
            PrimitiveArray::new(buffer![0u32, 10, 200], Validity::AllValid)
        );
        assert_eq!(nullable_u32.validity(), &Validity::AllValid);

        // And back again: nullable u32 -> non-nullable u8.
        let round_tripped = nullable_u32
            .to_array()
            .cast(DType::Primitive(PType::U8, Nullability::NonNullable))
            .unwrap()
            .to_primitive();
        assert_arrays_eq!(round_tripped, PrimitiveArray::from_iter([0u8, 10, 200]));
        assert_eq!(round_tripped.validity(), &Validity::NonNullable);
    }

    // Integer to float conversion keeps the numeric values.
    #[test]
    fn cast_u32_f32() {
        let source = buffer![0u32, 10, 200].into_array();
        let as_f32 = source.cast(PType::F32.into()).unwrap().to_primitive();
        assert_arrays_eq!(as_f32, PrimitiveArray::from_iter([0.0f32, 10., 200.]));
    }

    // A negative value cannot be represented as u32, so the cast must fail with a Compute
    // error once the result is canonicalized.
    #[test]
    fn cast_i32_u32() {
        let source = buffer![-1i32].into_array();
        let failure = source
            .cast(PType::U32.into())
            .and_then(|a| a.to_canonical().map(|c| c.into_array()))
            .unwrap_err();
        assert!(matches!(failure, VortexError::Compute(..)));
        assert!(failure.to_string().contains("Failed to cast -1 to U32"));
    }

    // An array that actually contains nulls cannot be cast to a non-nullable dtype.
    #[test]
    fn cast_array_with_nulls_to_nonnullable() {
        let with_nulls = PrimitiveArray::from_option_iter([Some(-1i32), None, Some(10)]);
        let failure = with_nulls
            .to_array()
            .cast(PType::I32.into())
            .and_then(|a| a.to_canonical().map(|c| c.into_array()))
            .unwrap_err();

        assert!(matches!(failure, VortexError::InvalidArgument(..)));
        assert!(
            failure
                .to_string()
                .contains("Cannot cast array with invalid values to non-nullable type.")
        );
    }

    // The -1 sits in a null slot, so it must not trigger a conversion failure when casting
    // to an unsigned type; the validity mask must be carried over unchanged.
    #[test]
    fn cast_with_invalid_nulls() {
        let source = PrimitiveArray::new(
            buffer![-1i32, 0, 10],
            Validity::from_iter([false, true, true]),
        );
        let as_u32 = source
            .to_array()
            .cast(DType::Primitive(PType::U32, Nullability::Nullable))
            .unwrap()
            .to_primitive();
        assert_arrays_eq!(
            as_u32,
            PrimitiveArray::from_option_iter([None, Some(0u32), Some(10)])
        );
        assert_eq!(
            as_u32.validity_mask().unwrap(),
            Mask::from(BitBuffer::from(vec![false, true, true]))
        );
    }

    // Runs the shared cast conformance suite over a spread of primitive types, including
    // nullable inputs and a single-element array.
    #[rstest]
    #[case(buffer![0u8, 1, 2, 3, 255].into_array())]
    #[case(buffer![0u16, 100, 1000, 65535].into_array())]
    #[case(buffer![0u32, 100, 1000, 1000000].into_array())]
    #[case(buffer![0u64, 100, 1000, 1000000000].into_array())]
    #[case(buffer![-128i8, -1, 0, 1, 127].into_array())]
    #[case(buffer![-1000i16, -1, 0, 1, 1000].into_array())]
    #[case(buffer![-1000000i32, -1, 0, 1, 1000000].into_array())]
    #[case(buffer![-1000000000i64, -1, 0, 1, 1000000000].into_array())]
    #[case(buffer![0.0f32, 1.5, -2.5, 100.0, 1e6].into_array())]
    #[case(buffer![0.0f64, 1.5, -2.5, 100.0, 1e12].into_array())]
    #[case(PrimitiveArray::from_option_iter([Some(1u8), None, Some(255), Some(0), None]).into_array())]
    #[case(PrimitiveArray::from_option_iter([Some(1i32), None, Some(-100), Some(0), None]).into_array())]
    #[case(buffer![42u32].into_array())]
    fn test_cast_primitive_conformance(#[case] array: crate::ArrayRef) {
        test_cast_conformance(array.as_ref());
    }
}