Skip to main content

vortex_alp/alp/compute/
cast.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_array::ArrayRef;
5use vortex_array::ArrayView;
6use vortex_array::IntoArray;
7use vortex_array::builtins::ArrayBuiltins;
8use vortex_array::dtype::DType;
9use vortex_array::scalar_fn::fns::cast::CastReduce;
10use vortex_error::VortexResult;
11
12use crate::ALPArrayExt;
13use crate::ALPArraySlotsExt;
14use crate::alp::ALP;
15
16impl CastReduce for ALP {
17    fn cast(array: ArrayView<'_, Self>, dtype: &DType) -> VortexResult<Option<ArrayRef>> {
18        // Check if this is just a nullability change
19        if !array.dtype().eq_ignore_nullability(dtype) {
20            return Ok(None);
21        }
22
23        // For nullability-only changes, we can avoid decoding
24        // Cast the encoded array (integers) to handle nullability
25        let new_encoded = array.encoded().cast(
26            array
27                .encoded()
28                .dtype()
29                .with_nullability(dtype.nullability()),
30        )?;
31
32        let new_patches = array
33            .patches()
34            .map(|p| p.map_values(|v| v.cast(dtype.clone())))
35            .transpose()?;
36
37        // SAFETY: casting nullability doesn't alter the invariants
38        unsafe {
39            Ok(Some(
40                ALP::new_unchecked(new_encoded, array.exponents(), new_patches).into_array(),
41            ))
42        }
43    }
44}
45
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_array::IntoArray;
    use vortex_array::LEGACY_SESSION;
    use vortex_array::VortexSessionExecute;
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::assert_arrays_eq;
    use vortex_array::builtins::ArrayBuiltins;
    use vortex_array::compute::conformance::cast::test_cast_conformance;
    use vortex_array::dtype::DType;
    use vortex_array::dtype::Nullability;
    use vortex_array::dtype::PType;
    use vortex_buffer::buffer;
    use vortex_error::VortexResult;

    use crate::alp::array::ALPArrayExt;
    use crate::alp_encode;

    /// Regression test for issue 5766: a nullability-only cast of an ALP array that has
    /// patches must decode to the same values as casting the plain input.
    #[test]
    fn issue_5766_test_cast_alp_with_patches_to_nullable() -> VortexResult<()> {
        let mut ctx = LEGACY_SESSION.create_execution_ctx();
        // NaN and infinity are included so that the encoder produces patches.
        let values = buffer![1.234f32, f32::NAN, 2.345, f32::INFINITY, 3.456].into_array();
        let values_primitive = values.clone().execute::<PrimitiveArray>(&mut ctx)?;
        let alp = alp_encode(values_primitive.as_view(), None, &mut ctx)?;

        assert!(
            alp.patches().is_some(),
            "Test requires ALP array with patches"
        );

        let nullable_dtype = DType::Primitive(PType::F32, Nullability::Nullable);
        let casted = alp.into_array().cast(nullable_dtype.clone())?;

        let expected = values.cast(nullable_dtype)?;

        let casted_prim = casted.execute::<PrimitiveArray>(&mut ctx)?;
        assert_arrays_eq!(casted_prim, expected);

        Ok(())
    }

    /// Widening f32 -> f64 changes the ptype, so it is not a nullability-only cast.
    /// The result must still have the requested dtype and the exact widened values.
    #[test]
    fn test_cast_alp_f32_to_f64() -> VortexResult<()> {
        let mut ctx = LEGACY_SESSION.create_execution_ctx();
        let values = buffer![1.5f32, 2.5, 3.5, 4.5].into_array();
        let values_primitive = values.execute::<PrimitiveArray>(&mut ctx)?;
        let alp = alp_encode(values_primitive.as_view(), None, &mut ctx)?;

        let casted = alp
            .into_array()
            .cast(DType::Primitive(PType::F64, Nullability::NonNullable))?;
        assert_eq!(
            casted.dtype(),
            &DType::Primitive(PType::F64, Nullability::NonNullable)
        );

        let decoded = casted.execute::<PrimitiveArray>(&mut ctx)?;
        // Every input is exactly representable in both f32 and f64, so all widened values
        // can be compared exactly rather than spot-checking a subset with a tolerance.
        assert_arrays_eq!(decoded, PrimitiveArray::from_iter([1.5f64, 2.5, 3.5, 4.5]));

        Ok(())
    }

    /// Casting ALP floats to an integer type yields the integral values.
    #[test]
    fn test_cast_alp_to_int() -> VortexResult<()> {
        let mut ctx = LEGACY_SESSION.create_execution_ctx();
        let values = buffer![1.0f32, 2.0, 3.0, 4.0].into_array();
        let values_primitive = values.execute::<PrimitiveArray>(&mut ctx)?;
        let alp = alp_encode(values_primitive.as_view(), None, &mut ctx)?;

        let casted = alp
            .into_array()
            .cast(DType::Primitive(PType::I32, Nullability::NonNullable))?;
        assert_eq!(
            casted.dtype(),
            &DType::Primitive(PType::I32, Nullability::NonNullable)
        );

        let decoded = casted.execute::<PrimitiveArray>(&mut ctx)?;
        assert_arrays_eq!(decoded, PrimitiveArray::from_iter([1i32, 2, 3, 4]));

        Ok(())
    }

    /// Runs the shared cast conformance suite against ALP-encoded inputs.
    #[rstest]
    #[case(buffer![1.23f32, 4.56, 7.89, 10.11, 12.13].into_array())]
    #[case(buffer![100.1f64, 200.2, 300.3, 400.4, 500.5].into_array())]
    #[case(PrimitiveArray::from_option_iter([Some(1.1f32), None, Some(2.2), Some(3.3), None]).into_array())]
    #[case(buffer![42.42f64].into_array())]
    #[case(buffer![0.0f32, -1.5, 2.5, -3.5, 4.5].into_array())]
    fn test_cast_alp_conformance(#[case] array: vortex_array::ArrayRef) -> VortexResult<()> {
        let mut ctx = LEGACY_SESSION.create_execution_ctx();
        let array_primitive = array.execute::<PrimitiveArray>(&mut ctx)?;
        // Propagate encoding failures through the test's `VortexResult` instead of panicking.
        let alp = alp_encode(array_primitive.as_view(), None, &mut ctx)?;
        test_cast_conformance(&alp.into_array());

        Ok(())
    }
}