Skip to main content

vortex_alp/alp/compute/
cast.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_array::ArrayRef;
5use vortex_array::ArrayView;
6use vortex_array::IntoArray;
7use vortex_array::builtins::ArrayBuiltins;
8use vortex_array::dtype::DType;
9use vortex_array::patches::Patches;
10use vortex_array::scalar_fn::fns::cast::CastReduce;
11use vortex_error::VortexResult;
12
13use crate::ALPArrayExt;
14use crate::ALPArraySlotsExt;
15use crate::alp::ALP;
16
17impl CastReduce for ALP {
18    fn cast(array: ArrayView<'_, Self>, dtype: &DType) -> VortexResult<Option<ArrayRef>> {
19        // Check if this is just a nullability change
20        if array.dtype().eq_ignore_nullability(dtype) {
21            // For nullability-only changes, we can avoid decoding
22            // Cast the encoded array (integers) to handle nullability
23            let new_encoded = array.encoded().cast(
24                array
25                    .encoded()
26                    .dtype()
27                    .with_nullability(dtype.nullability()),
28            )?;
29
30            let new_patches = array
31                .patches()
32                .map(|p| {
33                    if p.values().dtype() == dtype {
34                        Ok(p)
35                    } else {
36                        Patches::new(
37                            p.array_len(),
38                            p.offset(),
39                            p.indices().clone(),
40                            p.values().cast(dtype.clone())?,
41                            p.chunk_offsets().clone(),
42                        )
43                    }
44                })
45                .transpose()?;
46
47            // SAFETY: casting nullability doesn't alter the invariants
48            unsafe {
49                Ok(Some(
50                    ALP::new_unchecked(new_encoded, array.exponents(), new_patches).into_array(),
51                ))
52            }
53        } else {
54            Ok(None)
55        }
56    }
57}
58
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_array::IntoArray;
    use vortex_array::LEGACY_SESSION;
    use vortex_array::ToCanonical;
    use vortex_array::VortexSessionExecute;
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::assert_arrays_eq;
    use vortex_array::builtins::ArrayBuiltins;
    use vortex_array::compute::conformance::cast::test_cast_conformance;
    use vortex_array::dtype::DType;
    use vortex_array::dtype::Nullability;
    use vortex_array::dtype::PType;
    use vortex_buffer::buffer;
    use vortex_error::VortexResult;

    use crate::alp::array::ALPArrayExt;
    use crate::alp_encode;

    /// Regression test for issue 5766.
    ///
    /// Casting an ALP array that carries patches to a nullable dtype must produce the
    /// same values as casting the uncompressed input.
    #[test]
    fn issue_5766_test_cast_alp_with_patches_to_nullable() -> VortexResult<()> {
        let values = buffer![1.234f32, f32::NAN, 2.345, f32::INFINITY, 3.456].into_array();
        let alp = alp_encode(
            values.to_primitive().as_view(),
            None,
            &mut LEGACY_SESSION.create_execution_ctx(),
        )?;

        assert!(
            alp.patches().is_some(),
            "Test requires ALP array with patches"
        );

        let nullable_dtype = DType::Primitive(PType::F32, Nullability::Nullable);
        let casted = alp.into_array().cast(nullable_dtype.clone())?;

        let expected = values.cast(nullable_dtype)?;

        assert_arrays_eq!(casted.to_canonical()?.into_primitive(), expected);

        Ok(())
    }

    /// Widening f32 -> f64 changes more than nullability, so the ALP cast reduction
    /// returns `None` for it. The overall cast must still preserve every value.
    #[test]
    fn test_cast_alp_f32_to_f64() -> VortexResult<()> {
        let values = buffer![1.5f32, 2.5, 3.5, 4.5].into_array();
        let alp = alp_encode(
            values.to_primitive().as_view(),
            None,
            &mut LEGACY_SESSION.create_execution_ctx(),
        )?;

        let f64_dtype = DType::Primitive(PType::F64, Nullability::NonNullable);
        let casted = alp.into_array().cast(f64_dtype.clone())?;
        assert_eq!(casted.dtype(), &f64_dtype);

        // All inputs are exactly representable in both f32 and f64, so an exact
        // comparison of the whole array is valid.
        let decoded = casted.to_canonical()?.into_primitive();
        assert_arrays_eq!(decoded, PrimitiveArray::from_iter([1.5f64, 2.5, 3.5, 4.5]));

        Ok(())
    }

    /// Casting whole-number floats stored as ALP to i32 yields the matching integers.
    #[test]
    fn test_cast_alp_to_int() -> VortexResult<()> {
        let values = buffer![1.0f32, 2.0, 3.0, 4.0].into_array();
        let alp = alp_encode(
            values.to_primitive().as_view(),
            None,
            &mut LEGACY_SESSION.create_execution_ctx(),
        )?;

        let casted = alp
            .into_array()
            .cast(DType::Primitive(PType::I32, Nullability::NonNullable))?;
        assert_eq!(
            casted.dtype(),
            &DType::Primitive(PType::I32, Nullability::NonNullable)
        );

        let decoded = casted.to_canonical()?.into_primitive();
        assert_arrays_eq!(decoded, PrimitiveArray::from_iter([1i32, 2, 3, 4]));

        Ok(())
    }

    /// Runs the shared cast conformance suite against ALP-encoded versions of several
    /// inputs: f32/f64, nullable, single-element and mixed-sign.
    #[rstest]
    #[case(buffer![1.23f32, 4.56, 7.89, 10.11, 12.13].into_array())]
    #[case(buffer![100.1f64, 200.2, 300.3, 400.4, 500.5].into_array())]
    #[case(PrimitiveArray::from_option_iter([Some(1.1f32), None, Some(2.2), Some(3.3), None]).into_array())]
    #[case(buffer![42.42f64].into_array())]
    #[case(buffer![0.0f32, -1.5, 2.5, -3.5, 4.5].into_array())]
    fn test_cast_alp_conformance(#[case] array: vortex_array::ArrayRef) -> VortexResult<()> {
        // Propagate encoding failures through the test's Result, like the other tests,
        // rather than panicking.
        let alp = alp_encode(
            array.to_primitive().as_view(),
            None,
            &mut LEGACY_SESSION.create_execution_ctx(),
        )?;
        test_cast_conformance(&alp.into_array());

        Ok(())
    }
}