Skip to main content

vortex_alp/alp_rd/compute/
cast.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_array::ArrayRef;
5use vortex_array::ArrayView;
6use vortex_array::IntoArray;
7use vortex_array::builtins::ArrayBuiltins;
8use vortex_array::dtype::DType;
9use vortex_array::scalar_fn::fns::cast::CastReduce;
10use vortex_error::VortexResult;
11
12use crate::ALPRDArrayExt;
13use crate::alp_rd::ALPRD;
14
15impl CastReduce for ALPRD {
16    fn cast(array: ArrayView<'_, Self>, dtype: &DType) -> VortexResult<Option<ArrayRef>> {
17        // Check if this is just a nullability change
18        if !array.dtype().eq_ignore_nullability(dtype) {
19            return Ok(None);
20        }
21
22        // For nullability-only changes, we need to cast the left_parts array
23        // since it carries the validity information
24        let new_left_parts = array.left_parts().cast(
25            array
26                .left_parts()
27                .dtype()
28                .with_nullability(dtype.nullability()),
29        )?;
30
31        // NOTE: `CastReduce::cast` has a fixed trait signature without `ExecutionCtx`, so we
32        // construct a legacy ctx locally at this trait boundary.
33        Ok(Some(
34            unsafe {
35                ALPRD::new_unchecked(
36                    dtype.clone(),
37                    new_left_parts,
38                    array.left_parts_dictionary().clone(),
39                    array.right_parts().clone(),
40                    array.right_bit_width(),
41                    array.left_parts_patches(),
42                )
43            }
44            .into_array(),
45        ))
46    }
47}
48
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_array::IntoArray;
    use vortex_array::LEGACY_SESSION;
    use vortex_array::VortexSessionExecute;
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::builtins::ArrayBuiltins;
    use vortex_array::compute::conformance::cast::test_cast_conformance;
    use vortex_array::dtype::DType;
    use vortex_array::dtype::Nullability;
    use vortex_array::dtype::PType;

    use crate::RDEncoder;
    use crate::alp_rd::ALPRDArray;

    /// Encodes `arr` with `encoder`, using a fresh execution context created from
    /// `LEGACY_SESSION`. Shared by the conformance cases below.
    fn encode_with_legacy_ctx(arr: PrimitiveArray, encoder: RDEncoder) -> ALPRDArray {
        encoder.encode(arr.as_view(), &mut LEGACY_SESSION.create_execution_ctx())
    }

    /// Casting an f32 ALP-RD array to non-nullable f64 yields the requested dtype and
    /// decodes to (approximately) the original values.
    #[test]
    fn test_cast_alprd_f32_to_f64() {
        let mut ctx = LEGACY_SESSION.create_execution_ctx();
        let samples = vec![1.0f32, 1.1, 1.2, 1.3, 1.4];
        let arr = PrimitiveArray::from_iter(samples.clone());
        let alprd = RDEncoder::new(&samples).encode(arr.as_view(), &mut ctx);

        let target = DType::Primitive(PType::F64, Nullability::NonNullable);
        let casted = alprd.into_array().cast(target.clone()).unwrap();
        assert_eq!(casted.dtype(), &target);

        let decoded = casted.execute::<PrimitiveArray>(&mut ctx).unwrap();
        let widened = decoded.as_slice::<f64>();
        assert_eq!(widened.len(), 5);
        assert!((widened[0] - 1.0).abs() < f64::EPSILON);
        // 1.1 is not exactly representable in f32, so allow a looser tolerance after widening.
        assert!((widened[1] - 1.1).abs() < 1e-6);
    }

    /// An ALP-RD array containing nulls cannot be cast to non-nullable, but can be cast to
    /// the same primitive type with nullable validity.
    #[test]
    fn test_cast_alprd_nullable() {
        let mut ctx = LEGACY_SESSION.create_execution_ctx();
        let arr =
            PrimitiveArray::from_option_iter([Some(10.0f64), None, Some(10.1), Some(10.2), None]);
        let samples = vec![10.0f64, 10.1, 10.2];
        let alprd = RDEncoder::new(&samples).encode(arr.as_view(), &mut ctx);

        // The array holds nulls, so dropping nullability must fail. The cast result is also
        // executed so that an error raised at either step is observed.
        let non_nullable = DType::Primitive(PType::F64, Nullability::NonNullable);
        let result = alprd
            .clone()
            .into_array()
            .cast(non_nullable)
            .and_then(|pending| {
                pending
                    .execute::<PrimitiveArray>(&mut ctx)
                    .map(|primitive| primitive.into_array())
            });
        assert!(result.is_err(), "Expected error, got: {result:?}");

        // Keeping the type and asking for Nullable must succeed.
        let nullable = DType::Primitive(PType::F64, Nullability::Nullable);
        let casted = alprd.into_array().cast(nullable.clone()).unwrap();
        assert_eq!(casted.dtype(), &nullable);
    }

    /// Runs the shared cast conformance suite over a range of ALP-RD encoded inputs.
    #[rstest]
    #[case::f32({
        let samples = vec![1.23f32, 4.56, 7.89, 10.11, 12.13];
        encode_with_legacy_ctx(
            PrimitiveArray::from_iter(samples.clone()),
            RDEncoder::new(&samples),
        )
    })]
    #[case::f64({
        let samples = vec![100.1f64, 200.2, 300.3, 400.4, 500.5];
        encode_with_legacy_ctx(
            PrimitiveArray::from_iter(samples.clone()),
            RDEncoder::new(&samples),
        )
    })]
    #[case::single({
        let samples = vec![42.42f64];
        encode_with_legacy_ctx(
            PrimitiveArray::from_iter(samples.clone()),
            RDEncoder::new(&samples),
        )
    })]
    #[case::negative({
        let samples = vec![0.0f32, -1.5, 2.5, -3.5, 4.5];
        encode_with_legacy_ctx(
            PrimitiveArray::from_iter(samples.clone()),
            RDEncoder::new(&samples),
        )
    })]
    #[case::nullable({
        let arr = PrimitiveArray::from_option_iter([Some(1.1f32), None, Some(2.2), Some(3.3), None]);
        let samples = vec![1.1f32, 2.2, 3.3];
        encode_with_legacy_ctx(arr, RDEncoder::new(&samples))
    })]
    fn test_cast_alprd_conformance(#[case] alprd: ALPRDArray) {
        test_cast_conformance(&alprd.into_array());
    }
}