Skip to main content

vortex_alp/alp_rd/compute/
cast.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_array::ArrayRef;
5use vortex_array::ArrayView;
6use vortex_array::IntoArray;
7use vortex_array::builtins::ArrayBuiltins;
8use vortex_array::dtype::DType;
9use vortex_array::scalar_fn::fns::cast::CastReduce;
10use vortex_error::VortexResult;
11
12use crate::ALPRDArrayExt;
13use crate::alp_rd::ALPRD;
14
15impl CastReduce for ALPRD {
16    fn cast(array: ArrayView<'_, Self>, dtype: &DType) -> VortexResult<Option<ArrayRef>> {
17        // ALPRDArray stores floating-point values, so only cast between float types
18        // or if just changing nullability
19
20        // Check if this is just a nullability change
21        if array.dtype().eq_ignore_nullability(dtype) {
22            // For nullability-only changes, we need to cast the left_parts array
23            // since it carries the validity information
24            let new_left_parts = array.left_parts().cast(
25                array
26                    .left_parts()
27                    .dtype()
28                    .with_nullability(dtype.nullability()),
29            )?;
30
31            return Ok(Some(
32                ALPRD::try_new(
33                    dtype.clone(),
34                    new_left_parts,
35                    array.left_parts_dictionary().clone(),
36                    array.right_parts().clone(),
37                    array.right_bit_width(),
38                    array.left_parts_patches(),
39                )?
40                .into_array(),
41            ));
42        }
43
44        // For other casts (e.g., f32 to f64), decode to canonical and let PrimitiveArray handle it
45        Ok(None)
46    }
47}
48
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_array::IntoArray;
    use vortex_array::ToCanonical;
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::builtins::ArrayBuiltins;
    use vortex_array::compute::conformance::cast::test_cast_conformance;
    use vortex_array::dtype::DType;
    use vortex_array::dtype::Nullability;
    use vortex_array::dtype::PType;

    use crate::RDEncoder;

    /// Widening an ALP-RD encoded `f32` array to `f64` must yield the requested dtype and
    /// reproduce every original value.
    #[test]
    fn test_cast_alprd_f32_to_f64() {
        let values = vec![1.0f32, 1.1, 1.2, 1.3, 1.4];
        let arr = PrimitiveArray::from_iter(values.clone());
        let encoder = RDEncoder::new(&values);
        let alprd = encoder.encode(&arr);

        let casted = alprd
            .into_array()
            .cast(DType::Primitive(PType::F64, Nullability::NonNullable))
            .unwrap();
        assert_eq!(
            casted.dtype(),
            &DType::Primitive(PType::F64, Nullability::NonNullable)
        );

        // The encoding is expected to round-trip the stored f32 bits, and f32 -> f64
        // widening is exact. So every element must equal the widened original exactly.
        // Comparing the whole slice also catches decode errors in any position, not just
        // the first couple of elements.
        let decoded = casted.to_primitive();
        let expected: Vec<f64> = values.iter().copied().map(f64::from).collect();
        assert_eq!(decoded.as_slice::<f64>(), expected.as_slice());
    }

    /// Nullability-only casts: dropping nullability must fail when nulls are present, while
    /// keeping the same float type as nullable must succeed with the nullable dtype.
    #[test]
    fn test_cast_alprd_nullable() {
        let arr =
            PrimitiveArray::from_option_iter([Some(10.0f64), None, Some(10.1), Some(10.2), None]);
        let values = vec![10.0f64, 10.1, 10.2];
        let encoder = RDEncoder::new(&values);
        let alprd = encoder.encode(&arr);

        // The array contains nulls, so casting to a non-nullable dtype must be rejected.
        let result = alprd
            .clone()
            .into_array()
            .cast(DType::Primitive(PType::F64, Nullability::NonNullable));
        assert!(result.is_err());

        // Casting to the same float type as nullable must succeed.
        let casted = alprd
            .into_array()
            .cast(DType::Primitive(PType::F64, Nullability::Nullable))
            .unwrap();
        assert_eq!(
            casted.dtype(),
            &DType::Primitive(PType::F64, Nullability::Nullable)
        );
    }

    /// Runs the shared cast conformance suite over a range of ALP-RD encoded inputs.
    #[rstest]
    #[case::f32({
        let values = vec![1.23f32, 4.56, 7.89, 10.11, 12.13];
        let arr = PrimitiveArray::from_iter(values.clone());
        let encoder = RDEncoder::new(&values);
        encoder.encode(&arr)
    })]
    #[case::f64({
        let values = vec![100.1f64, 200.2, 300.3, 400.4, 500.5];
        let arr = PrimitiveArray::from_iter(values.clone());
        let encoder = RDEncoder::new(&values);
        encoder.encode(&arr)
    })]
    #[case::single({
        let values = vec![42.42f64];
        let arr = PrimitiveArray::from_iter(values.clone());
        let encoder = RDEncoder::new(&values);
        encoder.encode(&arr)
    })]
    #[case::negative({
        let values = vec![0.0f32, -1.5, 2.5, -3.5, 4.5];
        let arr = PrimitiveArray::from_iter(values.clone());
        let encoder = RDEncoder::new(&values);
        encoder.encode(&arr)
    })]
    #[case::nullable({
        let arr = PrimitiveArray::from_option_iter([Some(1.1f32), None, Some(2.2), Some(3.3), None]);
        let values = vec![1.1f32, 2.2, 3.3];
        let encoder = RDEncoder::new(&values);
        encoder.encode(&arr)
    })]
    fn test_cast_alprd_conformance(#[case] alprd: crate::alp_rd::ALPRDArray) {
        test_cast_conformance(&alprd.into_array());
    }
}