// Source: vortex_alp/alp_rd/compute/cast.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_array::ArrayRef;
5use vortex_array::IntoArray;
6use vortex_array::builtins::ArrayBuiltins;
7use vortex_array::dtype::DType;
8use vortex_array::scalar_fn::fns::cast::CastReduce;
9use vortex_error::VortexResult;
10
11use crate::alp_rd::ALPRDArray;
12use crate::alp_rd::ALPRDVTable;
13
14impl CastReduce for ALPRDVTable {
15    fn cast(array: &ALPRDArray, dtype: &DType) -> VortexResult<Option<ArrayRef>> {
16        // ALPRDArray stores floating-point values, so only cast between float types
17        // or if just changing nullability
18
19        // Check if this is just a nullability change
20        if array.dtype().eq_ignore_nullability(dtype) {
21            // For nullability-only changes, we need to cast the left_parts array
22            // since it carries the validity information
23            let new_left_parts = array.left_parts().cast(
24                array
25                    .left_parts()
26                    .dtype()
27                    .with_nullability(dtype.nullability()),
28            )?;
29
30            return Ok(Some(
31                ALPRDArray::try_new(
32                    dtype.clone(),
33                    new_left_parts,
34                    array.left_parts_dictionary().clone(),
35                    array.right_parts().clone(),
36                    array.right_bit_width(),
37                    array.left_parts_patches().cloned(),
38                )?
39                .into_array(),
40            ));
41        }
42
43        // For other casts (e.g., f32 to f64), decode to canonical and let PrimitiveArray handle it
44        Ok(None)
45    }
46}
47
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_array::IntoArray;
    use vortex_array::ToCanonical;
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::builtins::ArrayBuiltins;
    use vortex_array::compute::conformance::cast::test_cast_conformance;
    use vortex_array::dtype::DType;
    use vortex_array::dtype::Nullability;
    use vortex_array::dtype::PType;

    use crate::RDEncoder;

    /// Widening an ALP-RD `f32` array to `f64`.
    ///
    /// The nullability-only kernel does not handle this cast, so the generic fallback
    /// decodes it. Every `f32` is exactly representable as an `f64`, and ALP-RD is a
    /// lossless encoding. So each decoded element must equal `f64::from` of its input,
    /// with no tolerance.
    #[test]
    fn test_cast_alprd_f32_to_f64() {
        let values = vec![1.0f32, 1.1, 1.2, 1.3, 1.4];
        let arr = PrimitiveArray::from_iter(values.clone());
        let encoder = RDEncoder::new(&values);
        let alprd = encoder.encode(&arr);

        let casted = alprd
            .into_array()
            .cast(DType::Primitive(PType::F64, Nullability::NonNullable))
            .unwrap();
        assert_eq!(
            casted.dtype(),
            &DType::Primitive(PType::F64, Nullability::NonNullable)
        );

        let decoded = casted.to_primitive();
        let f64_values = decoded.as_slice::<f64>();
        // Check every element, not just a sample. The widening is exact, so compare exactly.
        let expected: Vec<f64> = values.iter().copied().map(f64::from).collect();
        assert_eq!(f64_values, expected.as_slice());
    }

    /// Nullability changes on an array that contains nulls.
    ///
    /// Casting to `NonNullable` must fail because nulls are present. Widening to
    /// `Nullable` must succeed and report the nullable dtype.
    #[test]
    fn test_cast_alprd_nullable() {
        let arr =
            PrimitiveArray::from_option_iter([Some(10.0f64), None, Some(10.1), Some(10.2), None]);
        let values = vec![10.0f64, 10.1, 10.2];
        let encoder = RDEncoder::new(&values);
        let alprd = encoder.encode(&arr);

        // Cast to NonNullable should fail since we have nulls
        let result = alprd
            .clone()
            .into_array()
            .cast(DType::Primitive(PType::F64, Nullability::NonNullable));
        assert!(result.is_err());

        // Cast to same type with Nullable should succeed
        let casted = alprd
            .into_array()
            .cast(DType::Primitive(PType::F64, Nullability::Nullable))
            .unwrap();
        assert_eq!(
            casted.dtype(),
            &DType::Primitive(PType::F64, Nullability::Nullable)
        );
    }

    /// Runs the shared cast conformance suite against a range of ALP-RD inputs.
    #[rstest]
    #[case::f32({
        let values = vec![1.23f32, 4.56, 7.89, 10.11, 12.13];
        let arr = PrimitiveArray::from_iter(values.clone());
        let encoder = RDEncoder::new(&values);
        encoder.encode(&arr)
    })]
    #[case::f64({
        let values = vec![100.1f64, 200.2, 300.3, 400.4, 500.5];
        let arr = PrimitiveArray::from_iter(values.clone());
        let encoder = RDEncoder::new(&values);
        encoder.encode(&arr)
    })]
    #[case::single({
        let values = vec![42.42f64];
        let arr = PrimitiveArray::from_iter(values.clone());
        let encoder = RDEncoder::new(&values);
        encoder.encode(&arr)
    })]
    #[case::negative({
        let values = vec![0.0f32, -1.5, 2.5, -3.5, 4.5];
        let arr = PrimitiveArray::from_iter(values.clone());
        let encoder = RDEncoder::new(&values);
        encoder.encode(&arr)
    })]
    #[case::nullable({
        let arr = PrimitiveArray::from_option_iter([Some(1.1f32), None, Some(2.2), Some(3.3), None]);
        let values = vec![1.1f32, 2.2, 3.3];
        let encoder = RDEncoder::new(&values);
        encoder.encode(&arr)
    })]
    fn test_cast_alprd_conformance(#[case] alprd: crate::alp_rd::ALPRDArray) {
        test_cast_conformance(&alprd.into_array());
    }
}