Skip to main content

vortex_alp/alp/compute/
cast.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_array::ArrayRef;
5use vortex_array::IntoArray;
6use vortex_array::builtins::ArrayBuiltins;
7use vortex_array::dtype::DType;
8use vortex_array::patches::Patches;
9use vortex_array::scalar_fn::fns::cast::CastReduce;
10use vortex_error::VortexResult;
11
12use crate::alp::ALPArray;
13use crate::alp::ALPVTable;
14
15impl CastReduce for ALPVTable {
16    fn cast(array: &ALPArray, dtype: &DType) -> VortexResult<Option<ArrayRef>> {
17        // Check if this is just a nullability change
18        if array.dtype().eq_ignore_nullability(dtype) {
19            // For nullability-only changes, we can avoid decoding
20            // Cast the encoded array (integers) to handle nullability
21            let new_encoded = array.encoded().cast(
22                array
23                    .encoded()
24                    .dtype()
25                    .with_nullability(dtype.nullability()),
26            )?;
27
28            let new_patches = array
29                .patches()
30                .map(|p| {
31                    if p.values().dtype() == dtype {
32                        Ok(p.clone())
33                    } else {
34                        Patches::new(
35                            p.array_len(),
36                            p.offset(),
37                            p.indices().clone(),
38                            p.values().cast(dtype.clone())?,
39                            p.chunk_offsets().clone(),
40                        )
41                    }
42                })
43                .transpose()?;
44
45            // SAFETY: casting nullability doesn't alter the invariants
46            unsafe {
47                Ok(Some(
48                    ALPArray::new_unchecked(
49                        new_encoded,
50                        array.exponents(),
51                        new_patches,
52                        dtype.clone(),
53                    )
54                    .into_array(),
55                ))
56            }
57        } else {
58            Ok(None)
59        }
60    }
61}
62
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_array::IntoArray;
    use vortex_array::ToCanonical;
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::assert_arrays_eq;
    use vortex_array::builtins::ArrayBuiltins;
    use vortex_array::compute::conformance::cast::test_cast_conformance;
    use vortex_array::dtype::DType;
    use vortex_array::dtype::Nullability;
    use vortex_array::dtype::PType;
    use vortex_buffer::buffer;
    use vortex_error::VortexResult;

    use crate::alp_encode;

    /// Regression test for issue 5766: casting an ALP array that has patches to a nullable
    /// dtype must decode to the same values as casting the uncompressed input.
    #[test]
    fn issue_5766_test_cast_alp_with_patches_to_nullable() -> VortexResult<()> {
        let values = buffer![1.234f32, f32::NAN, 2.345, f32::INFINITY, 3.456].into_array();
        let alp = alp_encode(&values.to_primitive(), None)?;

        assert!(
            alp.patches().is_some(),
            "Test requires ALP array with patches"
        );

        let nullable_dtype = DType::Primitive(PType::F32, Nullability::Nullable);
        let casted = alp.to_array().cast(nullable_dtype.clone())?;

        let expected = values.cast(nullable_dtype)?;

        assert_arrays_eq!(casted.to_canonical()?.into_primitive(), expected);

        Ok(())
    }

    /// Casting an f32 ALP array to f64 must produce the requested dtype and preserve
    /// every element.
    #[test]
    fn test_cast_alp_f32_to_f64() -> VortexResult<()> {
        let values = buffer![1.5f32, 2.5, 3.5, 4.5].into_array();
        let alp = alp_encode(&values.to_primitive(), None)?;

        let casted = alp
            .to_array()
            .cast(DType::Primitive(PType::F64, Nullability::NonNullable))?;
        assert_eq!(
            casted.dtype(),
            &DType::Primitive(PType::F64, Nullability::NonNullable)
        );

        // All inputs are exactly representable in f32 and f64, so compare every element
        // exactly instead of spot-checking only the first two.
        let decoded = casted.to_canonical()?.into_primitive();
        assert_arrays_eq!(decoded, PrimitiveArray::from_iter([1.5f64, 2.5, 3.5, 4.5]));

        Ok(())
    }

    /// Casting integral-valued floats to i32 must yield the corresponding integers.
    #[test]
    fn test_cast_alp_to_int() -> VortexResult<()> {
        let values = buffer![1.0f32, 2.0, 3.0, 4.0].into_array();
        let alp = alp_encode(&values.to_primitive(), None)?;

        let casted = alp
            .to_array()
            .cast(DType::Primitive(PType::I32, Nullability::NonNullable))?;
        assert_eq!(
            casted.dtype(),
            &DType::Primitive(PType::I32, Nullability::NonNullable)
        );

        let decoded = casted.to_canonical()?.into_primitive();
        assert_arrays_eq!(decoded, PrimitiveArray::from_iter([1i32, 2, 3, 4]));

        Ok(())
    }

    /// Runs the shared cast conformance suite over ALP encodings of several float inputs,
    /// including nullable and single-element arrays.
    #[rstest]
    #[case(buffer![1.23f32, 4.56, 7.89, 10.11, 12.13].into_array())]
    #[case(buffer![100.1f64, 200.2, 300.3, 400.4, 500.5].into_array())]
    #[case(PrimitiveArray::from_option_iter([Some(1.1f32), None, Some(2.2), Some(3.3), None]).into_array())]
    #[case(buffer![42.42f64].into_array())]
    #[case(buffer![0.0f32, -1.5, 2.5, -3.5, 4.5].into_array())]
    fn test_cast_alp_conformance(#[case] array: vortex_array::ArrayRef) -> VortexResult<()> {
        // The test already returns a Result, so propagate encoding failures with `?`
        // like the other tests instead of panicking.
        let alp = alp_encode(&array.to_primitive(), None)?;
        test_cast_conformance(&alp.to_array());

        Ok(())
    }
}