vortex_sparse/compute/
cast.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_array::compute::{CastKernel, CastKernelAdapter, cast};
5use vortex_array::{ArrayRef, IntoArray, register_kernel};
6use vortex_dtype::DType;
7use vortex_error::VortexResult;
8
9use crate::{SparseArray, SparseVTable};
10
11impl CastKernel for SparseVTable {
12    fn cast(&self, array: &SparseArray, dtype: &DType) -> VortexResult<Option<ArrayRef>> {
13        // Cast both the patches values and the fill value
14        let casted_fill = array.fill_scalar().cast(dtype)?;
15        let casted_patches = array
16            .patches()
17            .clone()
18            .map_values(|values| cast(&values, dtype))?;
19
20        Ok(Some(
21            SparseArray::try_new_from_patches(casted_patches, casted_fill)?.into_array(),
22        ))
23    }
24}
25
26register_kernel!(CastKernelAdapter(SparseVTable).lift());
27
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::compute::cast;
    use vortex_array::compute::conformance::cast::test_cast_conformance;
    use vortex_array::{IntoArray, ToCanonical};
    use vortex_buffer::buffer;
    use vortex_dtype::{DType, Nullability, PType};
    use vortex_scalar::Scalar;

    use crate::SparseArray;

    /// Widening a non-nullable `i32` sparse array to `i64` keeps it non-nullable and
    /// reproduces every element, both patched positions and fill positions.
    #[test]
    fn test_cast_sparse_i32_to_i64() {
        let sparse = SparseArray::try_new(
            buffer![2u64, 5, 8].into_array(),
            buffer![100i32, 200, 300].into_array(),
            10,
            Scalar::from(0i32),
        )
        .unwrap();

        let casted = cast(
            sparse.as_ref(),
            &DType::Primitive(PType::I64, Nullability::NonNullable),
        )
        .unwrap();
        assert_eq!(
            casted.dtype(),
            &DType::Primitive(PType::I64, Nullability::NonNullable)
        );

        // Compare the whole decoded buffer rather than a few indices, so a cast that
        // misplaced a patch or mis-cast the fill value anywhere is caught.
        let decoded = casted.to_primitive();
        assert_eq!(
            decoded.as_slice::<i64>(),
            &[0i64, 0, 100, 0, 0, 200, 0, 0, 300, 0]
        );
    }

    /// Casting a sparse array with a null fill to a wider nullable type yields the
    /// expected nullable dtype and keeps the patched values intact.
    #[test]
    fn test_cast_sparse_with_null_fill() {
        let sparse = SparseArray::try_new(
            buffer![1u64, 3, 5].into_array(),
            PrimitiveArray::from_option_iter([Some(42i32), Some(84), Some(126)]).into_array(),
            8,
            Scalar::null_typed::<i32>(),
        )
        .unwrap();

        let casted = cast(
            sparse.as_ref(),
            &DType::Primitive(PType::I64, Nullability::Nullable),
        )
        .unwrap();
        assert_eq!(
            casted.dtype(),
            &DType::Primitive(PType::I64, Nullability::Nullable)
        );

        // Previously only the dtype was checked; also verify the patch values survive.
        // Buffer contents at null (unpatched) positions are not meaningful, so only the
        // patched indices are compared.
        let decoded = casted.to_primitive();
        let values = decoded.as_slice::<i64>();
        assert_eq!(values.len(), 8);
        assert_eq!(values[1], 42);
        assert_eq!(values[3], 84);
        assert_eq!(values[5], 126);
    }

    /// Runs the shared cast conformance suite over sparse arrays with integer, float,
    /// nullable (null fill plus a null patch) and single-patch `u8` layouts.
    #[rstest]
    #[case(SparseArray::try_new(
        buffer![2u64, 5, 8].into_array(),
        buffer![100i32, 200, 300].into_array(),
        10,
        Scalar::from(0i32)
    ).unwrap())]
    #[case(SparseArray::try_new(
        buffer![0u64, 4, 9].into_array(),
        buffer![1.5f32, 2.5, 3.5].into_array(),
        10,
        Scalar::from(0.0f32)
    ).unwrap())]
    #[case(SparseArray::try_new(
        buffer![1u64, 3, 7].into_array(),
        PrimitiveArray::from_option_iter([Some(100i32), None, Some(300)]).into_array(),
        10,
        Scalar::null_typed::<i32>()
    ).unwrap())]
    #[case(SparseArray::try_new(
        buffer![5u64].into_array(),
        buffer![42u8].into_array(),
        10,
        Scalar::from(0u8)
    ).unwrap())]
    fn test_cast_sparse_conformance(#[case] array: SparseArray) {
        test_cast_conformance(array.as_ref());
    }
}