Skip to main content

vortex_sparse/compute/
cast.rs

// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_array::ArrayRef;
5use vortex_array::ArrayView;
6use vortex_array::IntoArray;
7use vortex_array::builtins::ArrayBuiltins;
8use vortex_array::dtype::DType;
9use vortex_array::scalar::Scalar;
10use vortex_array::scalar_fn::fns::cast::CastReduce;
11use vortex_error::VortexResult;
12
13use crate::Sparse;
14
15impl CastReduce for Sparse {
16    fn cast(array: ArrayView<'_, Self>, dtype: &DType) -> VortexResult<Option<ArrayRef>> {
17        let casted_patches = array
18            .patches()
19            .clone()
20            .map_values(|values| values.cast(dtype.clone()))?;
21
22        let casted_fill = if array.patches().num_patches() == array.len() {
23            // When every position is patched the fill scalar is unused and can be undefined.
24            // We skip casting it entirely and substitute a default value for the target dtype.
25            Scalar::default_value(dtype)
26        } else {
27            array.fill_scalar().cast(dtype)?
28        };
29
30        Ok(Some(
31            Sparse::try_new_from_patches(casted_patches, casted_fill)?.into_array(),
32        ))
33    }
34}
35
#[cfg(test)]
mod tests {
    use std::sync::LazyLock;

    use rstest::rstest;
    use vortex_array::IntoArray;
    use vortex_array::VortexSessionExecute;
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::assert_arrays_eq;
    use vortex_array::builtins::ArrayBuiltins;
    use vortex_array::compute::conformance::cast::test_cast_conformance;
    use vortex_array::dtype::DType;
    use vortex_array::dtype::Nullability;
    use vortex_array::dtype::PType;
    use vortex_array::scalar::Scalar;
    use vortex_array::session::ArraySession;
    use vortex_buffer::buffer;
    use vortex_session::VortexSession;

    use crate::Sparse;
    use crate::SparseArray;

    /// Lazily initialised session from which the tests obtain their execution contexts.
    static SESSION: LazyLock<VortexSession> =
        LazyLock::new(|| VortexSession::empty().with::<ArraySession>());

    /// Widening `i32` to `i64` must convert both the patch values and the fill value.
    #[test]
    fn test_cast_sparse_i32_to_i64() {
        let target = DType::Primitive(PType::I64, Nullability::NonNullable);
        let indices = buffer![2u64, 5, 8].into_array();
        let values = buffer![100i32, 200, 300].into_array();
        let sparse = Sparse::try_new(indices, values, 10, Scalar::from(0i32)).unwrap();

        let casted = sparse.into_array().cast(target.clone()).unwrap();
        assert_eq!(casted.dtype(), &target);

        // Positions 2, 5 and 8 hold the patches; every other position holds the fill (0).
        let mut ctx = SESSION.create_execution_ctx();
        let actual = casted.execute::<PrimitiveArray>(&mut ctx).unwrap();
        let expected = PrimitiveArray::from_iter([0i64, 0, 100, 0, 0, 200, 0, 0, 300, 0]);
        assert_arrays_eq!(actual, expected);
    }

    /// A null fill can be cast as long as the target dtype stays nullable.
    #[test]
    fn test_cast_sparse_with_null_fill() {
        let target = DType::Primitive(PType::I64, Nullability::Nullable);
        let indices = buffer![1u64, 3, 5].into_array();
        let values =
            PrimitiveArray::from_option_iter([Some(42i32), Some(84), Some(126)]).into_array();
        let sparse = Sparse::try_new(indices, values, 8, Scalar::null_native::<i32>()).unwrap();

        let casted = sparse.into_array().cast(target.clone()).unwrap();
        assert_eq!(casted.dtype(), &target);
    }

    /// Runs the shared cast conformance suite over a handful of sparse arrays.
    #[rstest]
    // Non-nullable i32 patches with a zero fill.
    #[case(Sparse::try_new(
        buffer![2u64, 5, 8].into_array(),
        buffer![100i32, 200, 300].into_array(),
        10,
        Scalar::from(0i32)
    ).unwrap())]
    // Non-nullable f32 patches with a zero fill.
    #[case(Sparse::try_new(
        buffer![0u64, 4, 9].into_array(),
        buffer![1.5f32, 2.5, 3.5].into_array(),
        10,
        Scalar::from(0.0f32)
    ).unwrap())]
    // Nullable i32 patches (one of them null) with a null fill.
    #[case(Sparse::try_new(
        buffer![1u64, 3, 7].into_array(),
        PrimitiveArray::from_option_iter([Some(100i32), None, Some(300)]).into_array(),
        10,
        Scalar::null_native::<i32>()
    ).unwrap())]
    // A single u8 patch with a zero fill.
    #[case(Sparse::try_new(
        buffer![5u64].into_array(),
        buffer![42u8].into_array(),
        10,
        Scalar::from(0u8)
    ).unwrap())]
    fn test_cast_sparse_conformance(#[case] array: SparseArray) {
        test_cast_conformance(&array.into_array());
    }

    #[test]
    fn test_cast_sparse_null_fill_all_patched_to_non_nullable() -> vortex_error::VortexResult<()> {
        // Regression test for https://github.com/vortex-data/vortex/issues/6932
        //
        // All five positions carry a patch, so the null fill is never read and casting
        // to a non-nullable dtype is legitimate. Sparse::cast recognises this situation
        // and uses the target dtype's default value as the fill rather than casting the
        // null scalar.
        let target = DType::Primitive(PType::U64, Nullability::NonNullable);
        let indices = buffer![0u64, 1, 2, 3, 4].into_array();
        let values = buffer![10u64, 20, 30, 40, 50].into_array();
        let sparse = Sparse::try_new(indices, values, 5, Scalar::null_native::<u64>())?;

        let casted = sparse.into_array().cast(target.clone())?;
        assert_eq!(casted.dtype(), &target);

        let mut ctx = SESSION.create_execution_ctx();
        let actual = casted.execute::<PrimitiveArray>(&mut ctx)?;
        let expected = PrimitiveArray::from_iter([10u64, 20, 30, 40, 50]);
        assert_arrays_eq!(actual, expected);
        Ok(())
    }

    #[test]
    fn test_fill_null_sparse_with_null_fill() -> vortex_error::VortexResult<()> {
        // Regression test for https://github.com/vortex-data/vortex/issues/6932
        //
        // Calling fill_null on a sparse array whose fill is null leads to an internal
        // cast to a non-nullable dtype; that cast must not panic.
        let indices = buffer![1u64, 3].into_array();
        let values = buffer![10u64, 20].into_array();
        let sparse = Sparse::try_new(indices, values, 5, Scalar::null_native::<u64>())?;

        let filled = sparse.into_array().fill_null(Scalar::from(0u64))?;

        let expected_dtype = DType::Primitive(PType::U64, Nullability::NonNullable);
        assert_eq!(filled.dtype(), &expected_dtype);
        Ok(())
    }
}