Skip to main content

vortex_sparse/compute/
cast.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_array::ArrayRef;
5use vortex_array::IntoArray;
6use vortex_array::builtins::ArrayBuiltins;
7use vortex_array::compute::CastReduce;
8use vortex_dtype::DType;
9use vortex_error::VortexResult;
10
11use crate::SparseArray;
12use crate::SparseVTable;
13
14impl CastReduce for SparseVTable {
15    fn cast(array: &SparseArray, dtype: &DType) -> VortexResult<Option<ArrayRef>> {
16        // Cast both the patches values and the fill value
17        let casted_fill = array.fill_scalar().cast(dtype)?;
18        let casted_patches = array
19            .patches()
20            .clone()
21            .map_values(|values| values.cast(dtype.clone()))?;
22
23        Ok(Some(
24            SparseArray::try_new_from_patches(casted_patches, casted_fill)?.into_array(),
25        ))
26    }
27}
28
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_array::IntoArray;
    use vortex_array::ToCanonical;
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::assert_arrays_eq;
    use vortex_array::builtins::ArrayBuiltins;
    use vortex_array::compute::conformance::cast::test_cast_conformance;
    use vortex_array::scalar::Scalar;
    use vortex_buffer::buffer;
    use vortex_dtype::DType;
    use vortex_dtype::Nullability;
    use vortex_dtype::PType;

    use crate::SparseArray;

    /// Widening a non-nullable `i32` sparse array to `i64` keeps both the patched values
    /// and the zero fill at every position.
    #[test]
    fn test_cast_sparse_i32_to_i64() {
        let sparse = SparseArray::try_new(
            buffer![2u64, 5, 8].into_array(),
            buffer![100i32, 200, 300].into_array(),
            10,
            Scalar::from(0i32),
        )
        .unwrap();

        let casted = sparse
            .to_array()
            .cast(DType::Primitive(PType::I64, Nullability::NonNullable))
            .unwrap();
        assert_eq!(
            casted.dtype(),
            &DType::Primitive(PType::I64, Nullability::NonNullable)
        );

        let expected = PrimitiveArray::from_iter([0i64, 0, 100, 0, 0, 200, 0, 0, 300, 0]);
        assert_arrays_eq!(casted.to_primitive(), expected);
    }

    /// Casting a sparse array with a null fill to a nullable, wider type must produce
    /// the target dtype. Patched positions keep their values and every other position
    /// stays null.
    #[test]
    fn test_cast_sparse_with_null_fill() {
        let sparse = SparseArray::try_new(
            buffer![1u64, 3, 5].into_array(),
            PrimitiveArray::from_option_iter([Some(42i32), Some(84), Some(126)]).into_array(),
            8,
            Scalar::null_native::<i32>(),
        )
        .unwrap();

        let casted = sparse
            .to_array()
            .cast(DType::Primitive(PType::I64, Nullability::Nullable))
            .unwrap();
        assert_eq!(
            casted.dtype(),
            &DType::Primitive(PType::I64, Nullability::Nullable)
        );

        // Previously only the dtype was checked; also verify the values survived the cast.
        let expected = PrimitiveArray::from_option_iter([
            None,
            Some(42i64),
            None,
            Some(84),
            None,
            Some(126),
            None,
            None,
        ]);
        assert_arrays_eq!(casted.to_primitive(), expected);
    }

    /// Runs the shared cast conformance suite over sparse arrays with these shapes:
    /// - integer values with a zero fill;
    /// - float values with a zero fill;
    /// - a null fill together with a null patch;
    /// - a single `u8` patch.
    #[rstest]
    #[case(SparseArray::try_new(
        buffer![2u64, 5, 8].into_array(),
        buffer![100i32, 200, 300].into_array(),
        10,
        Scalar::from(0i32)
    ).unwrap())]
    #[case(SparseArray::try_new(
        buffer![0u64, 4, 9].into_array(),
        buffer![1.5f32, 2.5, 3.5].into_array(),
        10,
        Scalar::from(0.0f32)
    ).unwrap())]
    #[case(SparseArray::try_new(
        buffer![1u64, 3, 7].into_array(),
        PrimitiveArray::from_option_iter([Some(100i32), None, Some(300)]).into_array(),
        10,
        Scalar::null_native::<i32>()
    ).unwrap())]
    #[case(SparseArray::try_new(
        buffer![5u64].into_array(),
        buffer![42u8].into_array(),
        10,
        Scalar::from(0u8)
    ).unwrap())]
    fn test_cast_sparse_conformance(#[case] array: SparseArray) {
        test_cast_conformance(array.as_ref());
    }
}