vortex_sparse/compute/
cast.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_array::ArrayRef;
5use vortex_array::IntoArray;
6use vortex_array::compute::CastKernel;
7use vortex_array::compute::CastKernelAdapter;
8use vortex_array::compute::cast;
9use vortex_array::register_kernel;
10use vortex_dtype::DType;
11use vortex_error::VortexResult;
12
13use crate::SparseArray;
14use crate::SparseVTable;
15
16impl CastKernel for SparseVTable {
17    fn cast(&self, array: &SparseArray, dtype: &DType) -> VortexResult<Option<ArrayRef>> {
18        // Cast both the patches values and the fill value
19        let casted_fill = array.fill_scalar().cast(dtype)?;
20        let casted_patches = array
21            .patches()
22            .clone()
23            .map_values(|values| cast(&values, dtype))?;
24
25        Ok(Some(
26            SparseArray::try_new_from_patches(casted_patches, casted_fill)?.into_array(),
27        ))
28    }
29}
30
// Register the `CastKernel` implementation for `SparseVTable`, wrapped in `CastKernelAdapter`.
// NOTE(review): presumably this is what lets the generic `cast` compute function dispatch to
// the sparse implementation — confirm against the `register_kernel!` macro definition.
register_kernel!(CastKernelAdapter(SparseVTable).lift());
32
/// Unit tests for casting `SparseArray` through the generic `cast` compute function.
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_array::IntoArray;
    use vortex_array::ToCanonical;
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::compute::cast;
    use vortex_array::compute::conformance::cast::test_cast_conformance;
    use vortex_buffer::buffer;
    use vortex_dtype::DType;
    use vortex_dtype::Nullability;
    use vortex_dtype::PType;
    use vortex_scalar::Scalar;

    use crate::SparseArray;

    /// Widening a non-nullable `i32` sparse array to `i64` must keep both the
    /// patched values and the fill value.
    #[test]
    fn test_cast_sparse_i32_to_i64() {
        let sparse = SparseArray::try_new(
            buffer![2u64, 5, 8].into_array(),
            buffer![100i32, 200, 300].into_array(),
            10,
            Scalar::from(0i32),
        )
        .unwrap();

        let casted = cast(
            sparse.as_ref(),
            &DType::Primitive(PType::I64, Nullability::NonNullable),
        )
        .unwrap();
        assert_eq!(
            casted.dtype(),
            &DType::Primitive(PType::I64, Nullability::NonNullable)
        );

        let decoded = casted.to_primitive();
        let values = decoded.as_slice::<i64>();
        // The cast must not change the logical length of the array.
        assert_eq!(values.len(), 10);
        assert_eq!(values[2], 100);
        assert_eq!(values[5], 200);
        assert_eq!(values[8], 300);
        assert_eq!(values[0], 0); // fill value
    }

    /// Casting a sparse array whose fill value is null must produce the nullable
    /// target dtype and must carry the patched values over unchanged.
    #[test]
    fn test_cast_sparse_with_null_fill() {
        let sparse = SparseArray::try_new(
            buffer![1u64, 3, 5].into_array(),
            PrimitiveArray::from_option_iter([Some(42i32), Some(84), Some(126)]).into_array(),
            8,
            Scalar::null_typed::<i32>(),
        )
        .unwrap();

        let casted = cast(
            sparse.as_ref(),
            &DType::Primitive(PType::I64, Nullability::Nullable),
        )
        .unwrap();
        assert_eq!(
            casted.dtype(),
            &DType::Primitive(PType::I64, Nullability::Nullable)
        );
        assert_eq!(casted.len(), 8);

        // Only the patched positions are inspected: the raw buffer contents behind
        // null (fill) positions are not something this test should depend on.
        let decoded = casted.to_primitive();
        let values = decoded.as_slice::<i64>();
        assert_eq!(values[1], 42);
        assert_eq!(values[3], 84);
        assert_eq!(values[5], 126);
    }

    /// Runs the shared cast conformance suite over a range of sparse arrays:
    /// integer and float values, nullable patches with a null fill, and a single patch.
    #[rstest]
    #[case(SparseArray::try_new(
        buffer![2u64, 5, 8].into_array(),
        buffer![100i32, 200, 300].into_array(),
        10,
        Scalar::from(0i32)
    ).unwrap())]
    #[case(SparseArray::try_new(
        buffer![0u64, 4, 9].into_array(),
        buffer![1.5f32, 2.5, 3.5].into_array(),
        10,
        Scalar::from(0.0f32)
    ).unwrap())]
    #[case(SparseArray::try_new(
        buffer![1u64, 3, 7].into_array(),
        PrimitiveArray::from_option_iter([Some(100i32), None, Some(300)]).into_array(),
        10,
        Scalar::null_typed::<i32>()
    ).unwrap())]
    #[case(SparseArray::try_new(
        buffer![5u64].into_array(),
        buffer![42u8].into_array(),
        10,
        Scalar::from(0u8)
    ).unwrap())]
    fn test_cast_sparse_conformance(#[case] array: SparseArray) {
        test_cast_conformance(array.as_ref());
    }
}