Skip to main content

vortex_array/arrays/varbin/compute/
cast.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_error::VortexResult;
5
6use crate::ArrayRef;
7use crate::ExecutionCtx;
8use crate::IntoArray;
9use crate::array::ArrayView;
10use crate::arrays::VarBin;
11use crate::arrays::VarBinArray;
12use crate::arrays::varbin::VarBinArrayExt;
13use crate::dtype::DType;
14use crate::scalar_fn::fns::cast::CastKernel;
15use crate::scalar_fn::fns::cast::CastReduce;
16use crate::validity::Validity;
17
18fn build_with_validity(
19    array: ArrayView<'_, VarBin>,
20    new_dtype: DType,
21    new_validity: Validity,
22) -> VortexResult<ArrayRef> {
23    Ok(VarBinArray::try_new(
24        array.offsets().clone(),
25        array.bytes().clone(),
26        new_dtype,
27        new_validity,
28    )?
29    .into_array())
30}
31
32impl CastReduce for VarBin {
33    fn cast(array: ArrayView<'_, VarBin>, dtype: &DType) -> VortexResult<Option<ArrayRef>> {
34        if !array.dtype().eq_ignore_nullability(dtype) {
35            return Ok(None);
36        }
37
38        let new_nullability = dtype.nullability();
39        let Some(new_validity) = array
40            .validity()?
41            .trivial_cast_nullability(new_nullability, array.len())?
42        else {
43            return Ok(None);
44        };
45        let new_dtype = array.dtype().with_nullability(new_nullability);
46        Ok(Some(build_with_validity(array, new_dtype, new_validity)?))
47    }
48}
49
50impl CastKernel for VarBin {
51    fn cast(
52        array: ArrayView<'_, VarBin>,
53        dtype: &DType,
54        ctx: &mut ExecutionCtx,
55    ) -> VortexResult<Option<ArrayRef>> {
56        if !array.dtype().eq_ignore_nullability(dtype) {
57            return Ok(None);
58        }
59
60        let new_nullability = dtype.nullability();
61        let new_validity = array
62            .validity()?
63            .cast_nullability(new_nullability, array.len(), ctx)?;
64        let new_dtype = array.dtype().with_nullability(new_nullability);
65        Ok(Some(build_with_validity(array, new_dtype, new_validity)?))
66    }
67}
68
#[cfg(test)]
mod tests {
    use std::sync::LazyLock;

    use rstest::rstest;
    use vortex_session::VortexSession;

    use crate::Canonical;
    use crate::IntoArray;
    use crate::VortexSessionExecute;
    use crate::arrays::VarBinArray;
    use crate::builtins::ArrayBuiltins;
    use crate::compute::conformance::cast::test_cast_conformance;
    use crate::dtype::DType;
    use crate::dtype::Nullability;
    use crate::session::ArraySession;

    /// Session used to create execution contexts for tests that force execution.
    static SESSION: LazyLock<VortexSession> =
        LazyLock::new(|| VortexSession::empty().with::<ArraySession>());

    /// Casting an all-valid array between the nullable and non-nullable variants of the same
    /// base type produces an array whose dtype is exactly the requested one.
    #[rstest]
    #[case(
        DType::Utf8(Nullability::Nullable),
        DType::Utf8(Nullability::NonNullable)
    )]
    #[case(
        DType::Binary(Nullability::Nullable),
        DType::Binary(Nullability::NonNullable)
    )]
    #[case(
        DType::Utf8(Nullability::NonNullable),
        DType::Utf8(Nullability::Nullable)
    )]
    #[case(
        DType::Binary(Nullability::NonNullable),
        DType::Binary(Nullability::Nullable)
    )]
    fn try_cast_varbin_nullable(#[case] source: DType, #[case] target: DType) {
        let input = VarBinArray::from_iter(vec![Some("a"), Some("b"), Some("c")], source);

        let casted = input.into_array().cast(target.clone()).unwrap();
        assert_eq!(casted.dtype(), &target);
    }

    /// Casting an array that contains a null to the non-nullable variant must fail.
    #[rstest]
    #[case(DType::Utf8(Nullability::Nullable))]
    #[case(DType::Binary(Nullability::Nullable))]
    fn try_cast_varbin_fail(#[case] source: DType) {
        // Failure surfaces during execution via the kernel, so the cast result is executed to
        // canonical form before checking for the error.
        let target = source.as_nonnullable();
        let input = VarBinArray::from_iter(vec![Some("a"), Some("b"), None], source);
        let mut ctx = SESSION.create_execution_ctx();

        let execute_canonical = |cast: crate::ArrayRef| {
            cast.execute::<Canonical>(&mut ctx)
                .map(|canonical| canonical.into_array())
        };
        let result = input.into_array().cast(target).and_then(execute_canonical);

        assert!(result.is_err(), "Expected error, got: {result:?}");
    }

    /// Runs the shared cast conformance suite over a mix of UTF-8 and binary inputs, with and
    /// without nulls.
    #[rstest]
    #[case(VarBinArray::from_iter(vec![Some("hello"), Some("world"), Some("test")], DType::Utf8(Nullability::NonNullable)))]
    #[case(VarBinArray::from_iter(vec![Some("hello"), None, Some("world")], DType::Utf8(Nullability::Nullable)))]
    #[case(VarBinArray::from_iter(vec![Some(b"binary".as_slice()), Some(b"data".as_slice())], DType::Binary(Nullability::NonNullable)))]
    #[case(VarBinArray::from_iter(vec![Some(b"test".as_slice()), None], DType::Binary(Nullability::Nullable)))]
    #[case(VarBinArray::from_iter(vec![Some("single")], DType::Utf8(Nullability::NonNullable)))]
    fn test_cast_varbin_conformance(#[case] array: VarBinArray) {
        let erased = array.into_array();
        test_cast_conformance(&erased);
    }
}