Skip to main content

vortex_array/arrays/primitive/compute/
cast.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_buffer::Buffer;
5use vortex_buffer::BufferMut;
6use vortex_dtype::DType;
7use vortex_dtype::NativePType;
8use vortex_dtype::match_each_native_ptype;
9use vortex_error::VortexResult;
10use vortex_error::vortex_err;
11use vortex_mask::AllOr;
12use vortex_mask::Mask;
13
14use crate::ArrayRef;
15use crate::ExecutionCtx;
16use crate::IntoArray;
17use crate::arrays::PrimitiveVTable;
18use crate::arrays::primitive::PrimitiveArray;
19use crate::compute::CastKernel;
20use crate::vtable::ValidityHelper;
21
22impl CastKernel for PrimitiveVTable {
23    fn cast(
24        array: &PrimitiveArray,
25        dtype: &DType,
26        _ctx: &mut ExecutionCtx,
27    ) -> VortexResult<Option<ArrayRef>> {
28        let DType::Primitive(new_ptype, new_nullability) = dtype else {
29            return Ok(None);
30        };
31        let (new_ptype, new_nullability) = (*new_ptype, *new_nullability);
32
33        // First, check that the cast is compatible with the source array's validity
34        let new_validity = array
35            .validity()
36            .clone()
37            .cast_nullability(new_nullability, array.len())?;
38
39        // If the bit width is the same, we can short-circuit and simply update the validity
40        if array.ptype() == new_ptype {
41            // SAFETY: validity and data buffer still have same length
42            return Ok(Some(unsafe {
43                PrimitiveArray::new_unchecked_from_handle(
44                    array.buffer_handle().clone(),
45                    array.ptype(),
46                    new_validity,
47                )
48                .into_array()
49            }));
50        }
51
52        let mask = array.validity_mask()?;
53
54        // Otherwise, we need to cast the values one-by-one
55        Ok(Some(match_each_native_ptype!(new_ptype, |T| {
56            match_each_native_ptype!(array.ptype(), |F| {
57                PrimitiveArray::new(cast::<F, T>(array.as_slice(), mask)?, new_validity)
58                    .into_array()
59            })
60        })))
61    }
62}
63
64fn cast<F: NativePType, T: NativePType>(array: &[F], mask: Mask) -> VortexResult<Buffer<T>> {
65    match mask.bit_buffer() {
66        AllOr::All => {
67            let mut buffer = BufferMut::with_capacity(array.len());
68            for item in array {
69                let item = T::from(*item).ok_or_else(
70                    || vortex_err!(Compute: "Failed to cast {} to {:?}", item, T::PTYPE),
71                )?;
72                // SAFETY: we've pre-allocated the required capacity
73                unsafe { buffer.push_unchecked(item) }
74            }
75            Ok(buffer.freeze())
76        }
77        AllOr::None => Ok(Buffer::zeroed(array.len())),
78        AllOr::Some(b) => {
79            // TODO(robert): Depending on density of the buffer might be better to prefill Buffer and only write valid values
80            let mut buffer = BufferMut::with_capacity(array.len());
81            for (item, valid) in array.iter().zip(b.iter()) {
82                if valid {
83                    let item = T::from(*item).ok_or_else(
84                        || vortex_err!(Compute: "Failed to cast {} to {:?}", item, T::PTYPE),
85                    )?;
86                    // SAFETY: we've pre-allocated the required capacity
87                    unsafe { buffer.push_unchecked(item) }
88                } else {
89                    // SAFETY: we've pre-allocated the required capacity
90                    unsafe { buffer.push_unchecked(T::default()) }
91                }
92            }
93            Ok(buffer.freeze())
94        }
95    }
96}
97
#[cfg(test)]
mod test {
    use rstest::rstest;
    use vortex_buffer::BitBuffer;
    use vortex_buffer::buffer;
    use vortex_dtype::DType;
    use vortex_dtype::Nullability;
    use vortex_dtype::PType;
    use vortex_error::VortexError;
    use vortex_mask::Mask;

    use crate::IntoArray;
    use crate::arrays::PrimitiveArray;
    use crate::assert_arrays_eq;
    use crate::builtins::ArrayBuiltins;
    use crate::canonical::ToCanonical;
    use crate::compute::conformance::cast::test_cast_conformance;
    use crate::validity::Validity;
    use crate::vtable::ValidityHelper;

    /// Casts `array` to `dtype` and converts the result to a [`PrimitiveArray`],
    /// panicking if the cast fails.
    fn cast_to_primitive(array: crate::ArrayRef, dtype: DType) -> PrimitiveArray {
        array.cast(dtype).unwrap().to_primitive()
    }

    /// Walks a u32 array through a chain of ptype and nullability casts, checking the
    /// values and the validity after every step.
    #[test]
    fn cast_u32_u8() {
        let source = buffer![0u32, 10, 200].into_array();

        // Step 1: u32 -> u8, nullability untouched.
        let narrowed = cast_to_primitive(source, PType::U8.into());
        assert_arrays_eq!(narrowed, PrimitiveArray::from_iter([0u8, 10, 200]));
        assert_eq!(narrowed.validity(), &Validity::NonNullable);

        // Step 2: same ptype, now nullable.
        let nullable_u8 = cast_to_primitive(
            narrowed.to_array(),
            DType::Primitive(PType::U8, Nullability::Nullable),
        );
        assert_arrays_eq!(
            nullable_u8,
            PrimitiveArray::new(buffer![0u8, 10, 200], Validity::AllValid)
        );
        assert_eq!(nullable_u8.validity(), &Validity::AllValid);

        // Step 3: same ptype, non-nullable again.
        let non_nullable_u8 = cast_to_primitive(
            nullable_u8.to_array(),
            DType::Primitive(PType::U8, Nullability::NonNullable),
        );
        assert_arrays_eq!(
            non_nullable_u8,
            PrimitiveArray::from_iter([0u8, 10, 200])
        );
        assert_eq!(non_nullable_u8.validity(), &Validity::NonNullable);

        // Step 4: widen to u32 and make nullable in a single cast.
        let nullable_u32 = cast_to_primitive(
            non_nullable_u8.to_array(),
            DType::Primitive(PType::U32, Nullability::Nullable),
        );
        assert_arrays_eq!(
            nullable_u32,
            PrimitiveArray::new(buffer![0u32, 10, 200], Validity::AllValid)
        );
        assert_eq!(nullable_u32.validity(), &Validity::AllValid);

        // Step 5: narrow to u8 and drop nullability in a single cast.
        let round_tripped = cast_to_primitive(
            nullable_u32.to_array(),
            DType::Primitive(PType::U8, Nullability::NonNullable),
        );
        assert_arrays_eq!(round_tripped, PrimitiveArray::from_iter([0u8, 10, 200]));
        assert_eq!(round_tripped.validity(), &Validity::NonNullable);
    }

    /// Integer values convert to the equivalent floats.
    #[test]
    fn cast_u32_f32() {
        let source = buffer![0u32, 10, 200].into_array();
        let floats = cast_to_primitive(source, PType::F32.into());
        assert_arrays_eq!(floats, PrimitiveArray::from_iter([0.0f32, 10., 200.]));
    }

    /// A valid negative value cannot be represented as unsigned and must fail the cast.
    #[test]
    fn cast_i32_u32() {
        let source = buffer![-1i32].into_array();
        let failure = source
            .cast(PType::U32.into())
            .and_then(|casted| casted.to_canonical().map(|canonical| canonical.into_array()))
            .unwrap_err();
        assert!(matches!(failure, VortexError::Compute(..)));
        assert!(failure.to_string().contains("Failed to cast -1 to U32"));
    }

    /// An array that actually contains nulls cannot be cast to a non-nullable dtype.
    #[test]
    fn cast_array_with_nulls_to_nonnullable() {
        let source = PrimitiveArray::from_option_iter([Some(-1i32), None, Some(10)]);
        let failure = source
            .to_array()
            .cast(PType::I32.into())
            .and_then(|casted| casted.to_canonical().map(|canonical| canonical.into_array()))
            .unwrap_err();

        assert!(matches!(failure, VortexError::InvalidArgument(..)));
        assert!(
            failure
                .to_string()
                .contains("Cannot cast array with invalid values to non-nullable type.")
        );
    }

    /// A value that would not fit the target type is ignored when its slot is invalid:
    /// the `-1` sits in a null position, so the cast to u32 still succeeds.
    #[test]
    fn cast_with_invalid_nulls() {
        let source = PrimitiveArray::new(
            buffer![-1i32, 0, 10],
            Validity::from_iter([false, true, true]),
        );
        let unsigned = cast_to_primitive(
            source.to_array(),
            DType::Primitive(PType::U32, Nullability::Nullable),
        );
        assert_arrays_eq!(
            unsigned,
            PrimitiveArray::from_option_iter([None, Some(0u32), Some(10)])
        );
        assert_eq!(
            unsigned.validity_mask().unwrap(),
            Mask::from(BitBuffer::from(vec![false, true, true]))
        );
    }

    /// Runs the shared cast conformance suite over a spread of primitive inputs:
    /// every integer width, both float widths, nullable arrays and a single-element array.
    #[rstest]
    #[case(buffer![0u8, 1, 2, 3, 255].into_array())]
    #[case(buffer![0u16, 100, 1000, 65535].into_array())]
    #[case(buffer![0u32, 100, 1000, 1000000].into_array())]
    #[case(buffer![0u64, 100, 1000, 1000000000].into_array())]
    #[case(buffer![-128i8, -1, 0, 1, 127].into_array())]
    #[case(buffer![-1000i16, -1, 0, 1, 1000].into_array())]
    #[case(buffer![-1000000i32, -1, 0, 1, 1000000].into_array())]
    #[case(buffer![-1000000000i64, -1, 0, 1, 1000000000].into_array())]
    #[case(buffer![0.0f32, 1.5, -2.5, 100.0, 1e6].into_array())]
    #[case(buffer![0.0f64, 1.5, -2.5, 100.0, 1e12].into_array())]
    #[case(PrimitiveArray::from_option_iter([Some(1u8), None, Some(255), Some(0), None]).into_array())]
    #[case(PrimitiveArray::from_option_iter([Some(1i32), None, Some(-100), Some(0), None]).into_array())]
    #[case(buffer![42u32].into_array())]
    fn test_cast_primitive_conformance(#[case] array: crate::ArrayRef) {
        test_cast_conformance(array.as_ref());
    }
}