vortex_array/arrays/primitive/compute/
cast.rs1use vortex_buffer::{Buffer, BufferMut};
2use vortex_dtype::{DType, NativePType, Nullability, match_each_native_ptype};
3use vortex_error::{VortexResult, vortex_bail, vortex_err};
4
5use crate::arrays::PrimitiveEncoding;
6use crate::arrays::primitive::PrimitiveArray;
7use crate::compute::CastFn;
8use crate::validity::Validity;
9use crate::variants::PrimitiveArrayTrait;
10use crate::{Array, ArrayRef};
11
12impl CastFn<&PrimitiveArray> for PrimitiveEncoding {
13 fn cast(&self, array: &PrimitiveArray, dtype: &DType) -> VortexResult<ArrayRef> {
14 let DType::Primitive(new_ptype, new_nullability) = dtype else {
15 vortex_bail!(MismatchedTypes: "primitive type", dtype);
16 };
17 let (new_ptype, new_nullability) = (*new_ptype, *new_nullability);
18
19 let new_validity = if array.dtype().nullability() == new_nullability {
21 array.validity().clone()
22 } else if new_nullability == Nullability::Nullable {
23 array.validity().clone().into_nullable()
25 } else if new_nullability == Nullability::NonNullable
26 && array.validity().to_logical(array.len())?.all_true()
27 {
28 Validity::NonNullable
30 } else {
31 vortex_bail!(
32 "invalid cast from nullable to non-nullable, since source array actually contains nulls"
33 );
34 };
35
36 if array.ptype() == new_ptype {
38 return Ok(PrimitiveArray::from_byte_buffer(
39 array.byte_buffer().clone(),
40 array.ptype(),
41 new_validity,
42 )
43 .into_array());
44 }
45
46 match_each_native_ptype!(new_ptype, |$T| {
48 Ok(PrimitiveArray::new(
49 cast::<$T>(array)?,
50 new_validity,
51 ).into_array())
52 })
53 }
54}
55
56fn cast<T: NativePType>(array: &PrimitiveArray) -> VortexResult<Buffer<T>> {
57 let mut buffer = BufferMut::with_capacity(array.len());
58 match_each_native_ptype!(array.ptype(), |$P| {
59 for item in array.as_slice::<$P>() {
60 let item = T::from(*item).ok_or_else(
61 || vortex_err!(ComputeError: "Failed to cast {} to {:?}", item, T::PTYPE),
62 )?;
63 unsafe { buffer.push_unchecked(item) }
65 }
66 });
67 Ok(buffer.freeze())
68}
69
#[cfg(test)]
mod test {
    use vortex_buffer::buffer;
    use vortex_dtype::{DType, Nullability, PType};
    use vortex_error::VortexError;

    use crate::IntoArray;
    use crate::arrays::PrimitiveArray;
    use crate::canonical::ToCanonical;
    use crate::compute::try_cast;
    use crate::validity::Validity;

    /// Walks one array through narrowing, widening and nullability changes, checking the
    /// values and the resulting validity after each step.
    #[test]
    fn cast_u32_u8() {
        let nullable_u8 = DType::Primitive(PType::U8, Nullability::Nullable);
        let non_nullable_u8 = DType::Primitive(PType::U8, Nullability::NonNullable);
        let nullable_u32 = DType::Primitive(PType::U32, Nullability::Nullable);

        let source = buffer![0u32, 10, 200].into_array();

        // u32 -> u8, still non-nullable.
        let narrowed = try_cast(&source, PType::U8.into())
            .unwrap()
            .to_primitive()
            .unwrap();
        assert_eq!(narrowed.as_slice::<u8>(), [0u8, 10, 200]);
        assert_eq!(narrowed.validity(), &Validity::NonNullable);

        // Non-nullable -> nullable with the same ptype.
        let made_nullable = try_cast(&narrowed, &nullable_u8)
            .unwrap()
            .to_primitive()
            .unwrap();
        assert_eq!(made_nullable.as_slice::<u8>(), [0u8, 10, 200]);
        assert_eq!(made_nullable.validity(), &Validity::AllValid);

        // Nullable (without nulls) -> non-nullable with the same ptype.
        let made_non_nullable = try_cast(&made_nullable, &non_nullable_u8)
            .unwrap()
            .to_primitive()
            .unwrap();
        assert_eq!(made_non_nullable.as_slice::<u8>(), [0u8, 10, 200]);
        assert_eq!(made_non_nullable.validity(), &Validity::NonNullable);

        // Change both the ptype and the nullability: u8 -> nullable u32.
        let widened = try_cast(&made_non_nullable, &nullable_u32)
            .unwrap()
            .to_primitive()
            .unwrap();
        assert_eq!(widened.as_slice::<u32>(), [0u32, 10, 200]);
        assert_eq!(widened.validity(), &Validity::AllValid);

        // And back again: nullable u32 -> non-nullable u8.
        let round_tripped = try_cast(&widened, &non_nullable_u8)
            .unwrap()
            .to_primitive()
            .unwrap();
        assert_eq!(round_tripped.as_slice::<u8>(), [0u8, 10, 200]);
        assert_eq!(round_tripped.validity(), &Validity::NonNullable);
    }

    /// Integer values convert to their floating-point equivalents.
    #[test]
    fn cast_u32_f32() {
        let source = buffer![0u32, 10, 200].into_array();
        let floats = try_cast(&source, PType::F32.into())
            .unwrap()
            .to_primitive()
            .unwrap();
        assert_eq!(floats.as_slice::<f32>(), [0.0f32, 10., 200.]);
    }

    /// A negative value cannot be represented as unsigned and yields a compute error.
    #[test]
    fn cast_i32_u32() {
        let source = buffer![-1i32].into_array();
        let failure = try_cast(&source, PType::U32.into()).unwrap_err();
        match failure {
            VortexError::ComputeError(message, _) => {
                assert_eq!(message.to_string(), "Failed to cast -1 to U32");
            }
            _ => unreachable!(),
        }
    }

    /// Dropping nullability is rejected when the array actually holds a null.
    #[test]
    fn cast_array_with_nulls_to_nonnullable() {
        let source = PrimitiveArray::from_option_iter([Some(-1i32), None, Some(10)]);
        let failure = try_cast(&source, PType::I32.into()).unwrap_err();
        match failure {
            VortexError::InvalidArgument(message, _) => assert_eq!(
                message.to_string(),
                "invalid cast from nullable to non-nullable, since source array actually contains nulls"
            ),
            _ => unreachable!(),
        }
    }
}