vortex_array/arrays/primitive/compute/cast.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_buffer::Buffer;
5use vortex_buffer::BufferMut;
6use vortex_dtype::DType;
7use vortex_dtype::NativePType;
8use vortex_dtype::match_each_native_ptype;
9use vortex_error::VortexResult;
10use vortex_error::vortex_err;
11use vortex_mask::AllOr;
12use vortex_mask::Mask;
13
14use crate::ArrayRef;
15use crate::IntoArray;
16use crate::arrays::PrimitiveVTable;
17use crate::arrays::primitive::PrimitiveArray;
18use crate::compute::CastKernel;
19use crate::compute::CastKernelAdapter;
20use crate::register_kernel;
21use crate::vtable::ValidityHelper;
22
23impl CastKernel for PrimitiveVTable {
24    fn cast(&self, array: &PrimitiveArray, dtype: &DType) -> VortexResult<Option<ArrayRef>> {
25        let DType::Primitive(new_ptype, new_nullability) = dtype else {
26            return Ok(None);
27        };
28        let (new_ptype, new_nullability) = (*new_ptype, *new_nullability);
29
30        // First, check that the cast is compatible with the source array's validity
31        let new_validity = array
32            .validity()
33            .clone()
34            .cast_nullability(new_nullability, array.len())?;
35
36        // If the bit width is the same, we can short-circuit and simply update the validity
37        if array.ptype() == new_ptype {
38            return Ok(Some(
39                PrimitiveArray::from_byte_buffer(
40                    array.byte_buffer().clone(),
41                    array.ptype(),
42                    new_validity,
43                )
44                .into_array(),
45            ));
46        }
47
48        let mask = array.validity_mask();
49
50        // Otherwise, we need to cast the values one-by-one
51        Ok(Some(match_each_native_ptype!(new_ptype, |T| {
52            match_each_native_ptype!(array.ptype(), |F| {
53                PrimitiveArray::new(cast::<F, T>(array.as_slice(), mask)?, new_validity)
54                    .into_array()
55            })
56        })))
57    }
58}
59
60register_kernel!(CastKernelAdapter(PrimitiveVTable).lift());
61
62fn cast<F: NativePType, T: NativePType>(array: &[F], mask: Mask) -> VortexResult<Buffer<T>> {
63    match mask.bit_buffer() {
64        AllOr::All => {
65            let mut buffer = BufferMut::with_capacity(array.len());
66            for item in array {
67                let item = T::from(*item).ok_or_else(
68                    || vortex_err!(ComputeError: "Failed to cast {} to {:?}", item, T::PTYPE),
69                )?;
70                // SAFETY: we've pre-allocated the required capacity
71                unsafe { buffer.push_unchecked(item) }
72            }
73            Ok(buffer.freeze())
74        }
75        AllOr::None => Ok(Buffer::zeroed(array.len())),
76        AllOr::Some(b) => {
77            // TODO(robert): Depending on density of the buffer might be better to prefill Buffer and only write valid values
78            let mut buffer = BufferMut::with_capacity(array.len());
79            for (item, valid) in array.iter().zip(b.iter()) {
80                if valid {
81                    let item = T::from(*item).ok_or_else(
82                        || vortex_err!(ComputeError: "Failed to cast {} to {:?}", item, T::PTYPE),
83                    )?;
84                    // SAFETY: we've pre-allocated the required capacity
85                    unsafe { buffer.push_unchecked(item) }
86                } else {
87                    // SAFETY: we've pre-allocated the required capacity
88                    unsafe { buffer.push_unchecked(T::default()) }
89                }
90            }
91            Ok(buffer.freeze())
92        }
93    }
94}
95
#[cfg(test)]
mod test {
    use rstest::rstest;
    use vortex_buffer::BitBuffer;
    use vortex_buffer::buffer;
    use vortex_dtype::DType;
    use vortex_dtype::Nullability;
    use vortex_dtype::PType;
    use vortex_error::VortexError;
    use vortex_mask::Mask;

    use crate::IntoArray;
    use crate::arrays::PrimitiveArray;
    use crate::canonical::ToCanonical;
    use crate::compute::cast;
    use crate::compute::conformance::cast::test_cast_conformance;
    use crate::validity::Validity;
    use crate::vtable::ValidityHelper;

    /// Round-trips in-range values through narrowing/widening casts and nullability changes.
    #[test]
    fn cast_u32_u8() {
        let arr = buffer![0u32, 10, 200].into_array();

        // cast from u32 to u8
        let p = cast(&arr, PType::U8.into()).unwrap().to_primitive();
        assert_eq!(p.as_slice::<u8>(), vec![0u8, 10, 200]);
        assert_eq!(p.validity(), &Validity::NonNullable);

        // to nullable
        let p = cast(
            p.as_ref(),
            &DType::Primitive(PType::U8, Nullability::Nullable),
        )
        .unwrap()
        .to_primitive();
        assert_eq!(p.as_slice::<u8>(), vec![0u8, 10, 200]);
        assert_eq!(p.validity(), &Validity::AllValid);

        // back to non-nullable
        let p = cast(
            p.as_ref(),
            &DType::Primitive(PType::U8, Nullability::NonNullable),
        )
        .unwrap()
        .to_primitive();
        assert_eq!(p.as_slice::<u8>(), vec![0u8, 10, 200]);
        assert_eq!(p.validity(), &Validity::NonNullable);

        // to nullable u32
        let p = cast(
            p.as_ref(),
            &DType::Primitive(PType::U32, Nullability::Nullable),
        )
        .unwrap()
        .to_primitive();
        assert_eq!(p.as_slice::<u32>(), vec![0u32, 10, 200]);
        assert_eq!(p.validity(), &Validity::AllValid);

        // to non-nullable u8
        let p = cast(
            p.as_ref(),
            &DType::Primitive(PType::U8, Nullability::NonNullable),
        )
        .unwrap()
        .to_primitive();
        assert_eq!(p.as_slice::<u8>(), vec![0u8, 10, 200]);
        assert_eq!(p.validity(), &Validity::NonNullable);
    }

    /// Small integers convert exactly to `f32`.
    #[test]
    fn cast_u32_f32() {
        let arr = buffer![0u32, 10, 200].into_array();
        let f32arr = cast(&arr, PType::F32.into()).unwrap().to_primitive();
        assert_eq!(f32arr.as_slice::<f32>(), vec![0.0f32, 10., 200.]);
    }

    /// A valid value that does not fit the target type is reported as a compute error.
    #[test]
    fn cast_i32_u32() {
        let arr = buffer![-1i32].into_array();
        let error = cast(&arr, PType::U32.into()).err().unwrap();
        let VortexError::ComputeError(s, _) = error else {
            unreachable!()
        };
        assert_eq!(s.to_string(), "Failed to cast -1 to U32");
    }

    /// Nulls cannot be cast away into a non-nullable type.
    #[test]
    fn cast_array_with_nulls_to_nonnullable() {
        let arr = PrimitiveArray::from_option_iter([Some(-1i32), None, Some(10)]);
        let err = cast(arr.as_ref(), PType::I32.into()).unwrap_err();
        let VortexError::InvalidArgument(s, _) = err else {
            unreachable!()
        };
        assert_eq!(
            s.to_string(),
            "Cannot cast array with invalid values to non-nullable type."
        );
    }

    /// Out-of-range values hidden behind nulls must not make the cast fail.
    #[test]
    fn cast_with_invalid_nulls() {
        let arr = PrimitiveArray::new(
            buffer![-1i32, 0, 10],
            Validity::from_iter([false, true, true]),
        );
        let p = cast(
            arr.as_ref(),
            &DType::Primitive(PType::U32, Nullability::Nullable),
        )
        .unwrap()
        .to_primitive();
        assert_eq!(p.as_slice::<u32>(), vec![0, 0, 10]);
        assert_eq!(
            p.validity_mask(),
            Mask::from(BitBuffer::from(vec![false, true, true]))
        );
    }

    /// An all-null array casts to zero-filled values and keeps every position null.
    #[test]
    fn cast_all_null() {
        let arr = PrimitiveArray::from_option_iter([None::<i32>, None]);
        let p = cast(
            arr.as_ref(),
            &DType::Primitive(PType::U8, Nullability::Nullable),
        )
        .unwrap()
        .to_primitive();
        assert_eq!(p.as_slice::<u8>(), vec![0u8, 0]);
        assert_eq!(
            p.validity_mask(),
            Mask::from(BitBuffer::from(vec![false, false]))
        );
    }

    #[rstest]
    #[case(buffer![0u8, 1, 2, 3, 255].into_array())]
    #[case(buffer![0u16, 100, 1000, 65535].into_array())]
    #[case(buffer![0u32, 100, 1000, 1000000].into_array())]
    #[case(buffer![0u64, 100, 1000, 1000000000].into_array())]
    #[case(buffer![-128i8, -1, 0, 1, 127].into_array())]
    #[case(buffer![-1000i16, -1, 0, 1, 1000].into_array())]
    #[case(buffer![-1000000i32, -1, 0, 1, 1000000].into_array())]
    #[case(buffer![-1000000000i64, -1, 0, 1, 1000000000].into_array())]
    #[case(buffer![0.0f32, 1.5, -2.5, 100.0, 1e6].into_array())]
    #[case(buffer![0.0f64, 1.5, -2.5, 100.0, 1e12].into_array())]
    #[case(PrimitiveArray::from_option_iter([Some(1u8), None, Some(255), Some(0), None]).into_array())]
    #[case(PrimitiveArray::from_option_iter([Some(1i32), None, Some(-100), Some(0), None]).into_array())]
    #[case(buffer![42u32].into_array())]
    fn test_cast_primitive_conformance(#[case] array: crate::ArrayRef) {
        test_cast_conformance(array.as_ref());
    }
}