Skip to main content

vortex_sparse/compute/
cast.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_array::ArrayRef;
5use vortex_array::IntoArray;
6use vortex_array::builtins::ArrayBuiltins;
7use vortex_array::dtype::DType;
8use vortex_array::scalar::Scalar;
9use vortex_array::scalar_fn::fns::cast::CastReduce;
10use vortex_error::VortexResult;
11
12use crate::Sparse;
13use crate::SparseArray;
14
15impl CastReduce for Sparse {
16    fn cast(array: &SparseArray, dtype: &DType) -> VortexResult<Option<ArrayRef>> {
17        let casted_patches = array
18            .patches()
19            .clone()
20            .map_values(|values| values.cast(dtype.clone()))?;
21
22        let casted_fill = if array.patches().num_patches() == array.len() {
23            // When every position is patched the fill scalar is unused and can be undefined.
24            // We skip casting it entirely and substitute a default value for the target dtype.
25            Scalar::default_value(dtype)
26        } else {
27            array.fill_scalar().cast(dtype)?
28        };
29
30        Ok(Some(
31            SparseArray::try_new_from_patches(casted_patches, casted_fill)?.into_array(),
32        ))
33    }
34}
35
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_array::IntoArray;
    use vortex_array::ToCanonical;
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::assert_arrays_eq;
    use vortex_array::builtins::ArrayBuiltins;
    use vortex_array::compute::conformance::cast::test_cast_conformance;
    use vortex_array::dtype::DType;
    use vortex_array::dtype::Nullability;
    use vortex_array::dtype::PType;
    use vortex_array::scalar::Scalar;
    use vortex_buffer::buffer;

    use crate::SparseArray;

    /// Widening a non-nullable sparse array keeps the fill at unpatched positions and the
    /// patch values at patched positions.
    #[test]
    fn test_cast_sparse_i32_to_i64() {
        let sparse = SparseArray::try_new(
            buffer![2u64, 5, 8].into_array(),
            buffer![100i32, 200, 300].into_array(),
            10,
            Scalar::from(0i32),
        )
        .unwrap();

        let casted = sparse
            .into_array()
            .cast(DType::Primitive(PType::I64, Nullability::NonNullable))
            .unwrap();
        assert_eq!(
            casted.dtype(),
            &DType::Primitive(PType::I64, Nullability::NonNullable)
        );

        let expected = PrimitiveArray::from_iter([0i64, 0, 100, 0, 0, 200, 0, 0, 300, 0]);
        assert_arrays_eq!(casted.to_primitive(), expected);
    }

    /// Casting a null-filled sparse array to a wider nullable type keeps nulls at the
    /// unpatched positions and the (widened) patch values at the patched positions.
    #[test]
    fn test_cast_sparse_with_null_fill() {
        let sparse = SparseArray::try_new(
            buffer![1u64, 3, 5].into_array(),
            PrimitiveArray::from_option_iter([Some(42i32), Some(84), Some(126)]).into_array(),
            8,
            Scalar::null_native::<i32>(),
        )
        .unwrap();

        let casted = sparse
            .into_array()
            .cast(DType::Primitive(PType::I64, Nullability::Nullable))
            .unwrap();
        assert_eq!(
            casted.dtype(),
            &DType::Primitive(PType::I64, Nullability::Nullable)
        );

        // Checking the dtype alone would not catch a cast that lost or misplaced values.
        let expected = PrimitiveArray::from_option_iter([
            None,
            Some(42i64),
            None,
            Some(84),
            None,
            Some(126),
            None,
            None,
        ]);
        assert_arrays_eq!(casted.to_primitive(), expected);
    }

    /// Runs the shared cast conformance suite over sparse arrays with integer, float,
    /// null-filled, and single-patch layouts.
    #[rstest]
    #[case(SparseArray::try_new(
        buffer![2u64, 5, 8].into_array(),
        buffer![100i32, 200, 300].into_array(),
        10,
        Scalar::from(0i32)
    ).unwrap())]
    #[case(SparseArray::try_new(
        buffer![0u64, 4, 9].into_array(),
        buffer![1.5f32, 2.5, 3.5].into_array(),
        10,
        Scalar::from(0.0f32)
    ).unwrap())]
    #[case(SparseArray::try_new(
        buffer![1u64, 3, 7].into_array(),
        PrimitiveArray::from_option_iter([Some(100i32), None, Some(300)]).into_array(),
        10,
        Scalar::null_native::<i32>()
    ).unwrap())]
    #[case(SparseArray::try_new(
        buffer![5u64].into_array(),
        buffer![42u8].into_array(),
        10,
        Scalar::from(0u8)
    ).unwrap())]
    fn test_cast_sparse_conformance(#[case] array: SparseArray) {
        test_cast_conformance(&array.into_array());
    }

    #[test]
    fn test_cast_sparse_null_fill_all_patched_to_non_nullable() -> vortex_error::VortexResult<()> {
        // Regression test for https://github.com/vortex-data/vortex/issues/6932
        //
        // When all positions are patched the null fill is unused, so a cast to
        // non-nullable is valid.  Sparse::cast detects this case, substitutes a
        // default fill for the target dtype instead of casting the null fill.
        let sparse = SparseArray::try_new(
            buffer![0u64, 1, 2, 3, 4].into_array(),
            buffer![10u64, 20, 30, 40, 50].into_array(),
            5,
            Scalar::null_native::<u64>(),
        )?;

        let casted = sparse
            .into_array()
            .cast(DType::Primitive(PType::U64, Nullability::NonNullable))?;

        assert_eq!(
            casted.dtype(),
            &DType::Primitive(PType::U64, Nullability::NonNullable)
        );

        let expected = PrimitiveArray::from_iter([10u64, 20, 30, 40, 50]);
        assert_arrays_eq!(casted.to_primitive(), expected);
        Ok(())
    }

    #[test]
    fn test_fill_null_sparse_with_null_fill() -> vortex_error::VortexResult<()> {
        // Regression test for https://github.com/vortex-data/vortex/issues/6932
        // fill_null on a sparse array with null fill triggers an internal cast to
        // non-nullable, which must not panic.
        let sparse = SparseArray::try_new(
            buffer![1u64, 3].into_array(),
            buffer![10u64, 20].into_array(),
            5,
            Scalar::null_native::<u64>(),
        )?;

        let filled = sparse.into_array().fill_null(Scalar::from(0u64))?;

        assert_eq!(
            filled.dtype(),
            &DType::Primitive(PType::U64, Nullability::NonNullable)
        );

        // Every previously-null (unpatched) position must now hold the fill value, and the
        // patched positions must be untouched.
        let expected = PrimitiveArray::from_iter([0u64, 10, 0, 20, 0]);
        assert_arrays_eq!(filled.to_primitive(), expected);
        Ok(())
    }
}