vortex_array/arrays/primitive/compute/
cast.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_buffer::{Buffer, BufferMut};
5use vortex_dtype::{DType, NativePType, Nullability, match_each_native_ptype};
6use vortex_error::{VortexResult, vortex_bail, vortex_err};
7
8use crate::arrays::PrimitiveVTable;
9use crate::arrays::primitive::PrimitiveArray;
10use crate::compute::{CastKernel, CastKernelAdapter};
11use crate::validity::Validity;
12use crate::vtable::ValidityHelper;
13use crate::{ArrayRef, IntoArray, register_kernel};
14
15impl CastKernel for PrimitiveVTable {
16    fn cast(&self, array: &PrimitiveArray, dtype: &DType) -> VortexResult<Option<ArrayRef>> {
17        let DType::Primitive(new_ptype, new_nullability) = dtype else {
18            return Ok(None);
19        };
20        let (new_ptype, new_nullability) = (*new_ptype, *new_nullability);
21
22        // First, check that the cast is compatible with the source array's validity
23        let new_validity = if array.dtype().nullability() == new_nullability {
24            array.validity().clone()
25        } else if new_nullability == Nullability::Nullable {
26            // from non-nullable to nullable
27            array.validity().clone().into_nullable()
28        } else if new_nullability == Nullability::NonNullable && array.validity().all_valid() {
29            // from nullable but all valid, to non-nullable
30            Validity::NonNullable
31        } else {
32            vortex_bail!(
33                "invalid cast from nullable to non-nullable, since source array actually contains nulls"
34            );
35        };
36
37        // If the bit width is the same, we can short-circuit and simply update the validity
38        if array.ptype() == new_ptype {
39            return Ok(Some(
40                PrimitiveArray::from_byte_buffer(
41                    array.byte_buffer().clone(),
42                    array.ptype(),
43                    new_validity,
44                )
45                .into_array(),
46            ));
47        }
48
49        // Otherwise, we need to cast the values one-by-one
50        match_each_native_ptype!(new_ptype, |T| {
51            Ok(Some(
52                PrimitiveArray::new(cast::<T>(array)?, new_validity).into_array(),
53            ))
54        })
55    }
56}
57
58register_kernel!(CastKernelAdapter(PrimitiveVTable).lift());
59
60fn cast<T: NativePType>(array: &PrimitiveArray) -> VortexResult<Buffer<T>> {
61    let mut buffer = BufferMut::with_capacity(array.len());
62    match_each_native_ptype!(array.ptype(), |P| {
63        for item in array.as_slice::<P>() {
64            let item = T::from(*item).ok_or_else(
65                || vortex_err!(ComputeError: "Failed to cast {} to {:?}", item, T::PTYPE),
66            )?;
67            // SAFETY: we've pre-allocated the required capacity
68            unsafe { buffer.push_unchecked(item) }
69        }
70    });
71    Ok(buffer.freeze())
72}
73
#[cfg(test)]
mod test {
    use rstest::rstest;
    use vortex_buffer::buffer;
    use vortex_dtype::{DType, Nullability, PType};
    use vortex_error::VortexError;

    use crate::IntoArray;
    use crate::arrays::PrimitiveArray;
    use crate::canonical::ToCanonical;
    use crate::compute::cast;
    use crate::compute::conformance::cast::test_cast_conformance;
    use crate::validity::Validity;
    use crate::vtable::ValidityHelper;

    /// Walks a u32 array through a chain of ptype and nullability casts, checking values and
    /// validity after every step.
    #[test]
    fn cast_u32_u8() {
        // Casts a primitive array to the given ptype/nullability and canonicalizes the result.
        let recast = |source: &PrimitiveArray, ptype: PType, nullability: Nullability| {
            cast(source.as_ref(), &DType::Primitive(ptype, nullability))
                .unwrap()
                .to_primitive()
        };

        let source = buffer![0u32, 10, 200].into_array();

        // u32 -> u8
        let narrowed = cast(&source, PType::U8.into()).unwrap().to_primitive();
        assert_eq!(narrowed.as_slice::<u8>(), [0u8, 10, 200]);
        assert_eq!(narrowed.validity(), &Validity::NonNullable);

        // u8 -> nullable u8
        let nullable_u8 = recast(&narrowed, PType::U8, Nullability::Nullable);
        assert_eq!(nullable_u8.as_slice::<u8>(), [0u8, 10, 200]);
        assert_eq!(nullable_u8.validity(), &Validity::AllValid);

        // nullable u8 -> non-nullable u8
        let non_nullable_u8 = recast(&nullable_u8, PType::U8, Nullability::NonNullable);
        assert_eq!(non_nullable_u8.as_slice::<u8>(), [0u8, 10, 200]);
        assert_eq!(non_nullable_u8.validity(), &Validity::NonNullable);

        // non-nullable u8 -> nullable u32
        let nullable_u32 = recast(&non_nullable_u8, PType::U32, Nullability::Nullable);
        assert_eq!(nullable_u32.as_slice::<u32>(), [0u32, 10, 200]);
        assert_eq!(nullable_u32.validity(), &Validity::AllValid);

        // nullable u32 -> non-nullable u8
        let round_trip = recast(&nullable_u32, PType::U8, Nullability::NonNullable);
        assert_eq!(round_trip.as_slice::<u8>(), [0u8, 10, 200]);
        assert_eq!(round_trip.validity(), &Validity::NonNullable);
    }

    /// Integer values survive a cast to a floating point ptype.
    #[test]
    fn cast_u32_f32() {
        let source = buffer![0u32, 10, 200].into_array();
        let floats = cast(&source, PType::F32.into()).unwrap().to_primitive();
        assert_eq!(floats.as_slice::<f32>(), [0.0f32, 10., 200.]);
    }

    /// A negative value cannot be represented as an unsigned type and is reported as a
    /// compute error.
    #[test]
    fn cast_i32_u32() {
        let source = buffer![-1i32].into_array();
        match cast(&source, PType::U32.into()).err().unwrap() {
            VortexError::ComputeError(message, _) => {
                assert_eq!(message.to_string(), "Failed to cast -1 to U32");
            }
            _ => unreachable!(),
        }
    }

    /// Dropping nullability is rejected when the source really contains nulls.
    #[test]
    fn cast_array_with_nulls_to_nonnullable() {
        let source = PrimitiveArray::from_option_iter([Some(-1i32), None, Some(10)]);
        match cast(source.as_ref(), PType::I32.into()).unwrap_err() {
            VortexError::InvalidArgument(message, _) => {
                assert_eq!(
                    message.to_string(),
                    "invalid cast from nullable to non-nullable, since source array actually contains nulls"
                );
            }
            _ => unreachable!(),
        }
    }

    /// Runs the shared cast conformance suite over a spread of primitive arrays.
    #[rstest]
    #[case(buffer![0u8, 1, 2, 3, 255].into_array())]
    #[case(buffer![0u16, 100, 1000, 65535].into_array())]
    #[case(buffer![0u32, 100, 1000, 1000000].into_array())]
    #[case(buffer![0u64, 100, 1000, 1000000000].into_array())]
    #[case(buffer![-128i8, -1, 0, 1, 127].into_array())]
    #[case(buffer![-1000i16, -1, 0, 1, 1000].into_array())]
    #[case(buffer![-1000000i32, -1, 0, 1, 1000000].into_array())]
    #[case(buffer![-1000000000i64, -1, 0, 1, 1000000000].into_array())]
    #[case(buffer![0.0f32, 1.5, -2.5, 100.0, 1e6].into_array())]
    #[case(buffer![0.0f64, 1.5, -2.5, 100.0, 1e12].into_array())]
    #[case(PrimitiveArray::from_option_iter([Some(1u8), None, Some(255), Some(0), None]).into_array())]
    #[case(PrimitiveArray::from_option_iter([Some(1i32), None, Some(-100), Some(0), None]).into_array())]
    #[case(buffer![42u32].into_array())]
    fn test_cast_primitive_conformance(#[case] array: crate::ArrayRef) {
        test_cast_conformance(array.as_ref());
    }
}