Skip to main content

vortex_array/arrays/varbinview/compute/
cast.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::sync::Arc;
5
6use vortex_error::VortexResult;
7
8use crate::ArrayRef;
9use crate::ExecutionCtx;
10use crate::IntoArray;
11use crate::array::ArrayView;
12use crate::arrays::VarBinView;
13use crate::arrays::VarBinViewArray;
14use crate::dtype::DType;
15use crate::scalar_fn::fns::cast::CastKernel;
16use crate::scalar_fn::fns::cast::CastReduce;
17use crate::validity::Validity;
18
19fn build_with_validity(
20    array: ArrayView<'_, VarBinView>,
21    new_dtype: DType,
22    new_validity: Validity,
23) -> ArrayRef {
24    // SAFETY: casting just changes the DType, does not affect invariants on views/buffers.
25    unsafe {
26        VarBinViewArray::new_handle_unchecked(
27            array.views_handle().clone(),
28            Arc::clone(array.data_buffers()),
29            new_dtype,
30            new_validity,
31        )
32        .into_array()
33    }
34}
35
36impl CastReduce for VarBinView {
37    fn cast(array: ArrayView<'_, VarBinView>, dtype: &DType) -> VortexResult<Option<ArrayRef>> {
38        if !array.dtype().eq_ignore_nullability(dtype) {
39            return Ok(None);
40        }
41
42        let new_nullability = dtype.nullability();
43        let Some(new_validity) = array
44            .validity()?
45            .trivial_cast_nullability(new_nullability, array.len())?
46        else {
47            return Ok(None);
48        };
49        let new_dtype = array.dtype().with_nullability(new_nullability);
50        Ok(Some(build_with_validity(array, new_dtype, new_validity)))
51    }
52}
53
54impl CastKernel for VarBinView {
55    fn cast(
56        array: ArrayView<'_, VarBinView>,
57        dtype: &DType,
58        ctx: &mut ExecutionCtx,
59    ) -> VortexResult<Option<ArrayRef>> {
60        if !array.dtype().eq_ignore_nullability(dtype) {
61            return Ok(None);
62        }
63
64        let new_nullability = dtype.nullability();
65        let new_validity = array
66            .validity()?
67            .cast_nullability(new_nullability, array.len(), ctx)?;
68        let new_dtype = array.dtype().with_nullability(new_nullability);
69        Ok(Some(build_with_validity(array, new_dtype, new_validity)))
70    }
71}
72
#[cfg(test)]
mod tests {
    use std::sync::LazyLock;

    use rstest::rstest;
    use vortex_session::VortexSession;

    use crate::Canonical;
    use crate::IntoArray;
    use crate::VortexSessionExecute;
    use crate::arrays::VarBinViewArray;
    use crate::builtins::ArrayBuiltins;
    use crate::compute::conformance::cast::test_cast_conformance;
    use crate::dtype::DType;
    use crate::dtype::Nullability;
    use crate::session::ArraySession;

    /// Session shared by tests that need an execution context to force evaluation.
    static SESSION: LazyLock<VortexSession> =
        LazyLock::new(|| VortexSession::empty().with::<ArraySession>());

    /// Nullability-only casts of a null-free array report exactly the requested dtype.
    #[rstest]
    #[case(DType::Utf8(Nullability::Nullable), DType::Utf8(Nullability::NonNullable))]
    #[case(DType::Binary(Nullability::Nullable), DType::Binary(Nullability::NonNullable))]
    #[case(DType::Utf8(Nullability::NonNullable), DType::Utf8(Nullability::Nullable))]
    #[case(DType::Binary(Nullability::NonNullable), DType::Binary(Nullability::Nullable))]
    fn try_cast_varbin_nullable(#[case] source: DType, #[case] target: DType) {
        let values = vec![Some("a"), Some("b"), Some("c")];
        let cast = VarBinViewArray::from_iter(values, source)
            .into_array()
            .cast(target.clone())
            .unwrap();
        assert_eq!(cast.dtype(), &target);
    }

    /// Casting an array that holds a null to the non-nullable dtype must fail.
    #[rstest]
    #[case(DType::Utf8(Nullability::Nullable))]
    #[case(DType::Binary(Nullability::Nullable))]
    fn try_cast_varbin_fail(#[case] source: DType) {
        // The failure is expected while executing the cast (via the kernel), so the result is
        // forced all the way to `Canonical`.
        let target = source.as_nonnullable();
        let with_null = VarBinViewArray::from_iter(vec![Some("a"), Some("b"), None], source);
        let mut ctx = SESSION.create_execution_ctx();
        let result = with_null
            .into_array()
            .cast(target)
            .and_then(|cast| cast.execute::<Canonical>(&mut ctx))
            .map(|canonical| canonical.into_array());
        assert!(result.is_err(), "Expected error, got: {result:?}");
    }

    /// Runs the shared cast conformance suite over inline, nullable, binary and
    /// buffer-backed (long string) inputs.
    #[rstest]
    #[case(VarBinViewArray::from_iter(vec![Some("hello"), Some("world"), Some("test")], DType::Utf8(Nullability::NonNullable)))]
    #[case(VarBinViewArray::from_iter(vec![Some("hello"), None, Some("world")], DType::Utf8(Nullability::Nullable)))]
    #[case(VarBinViewArray::from_iter(vec![Some(b"binary".as_slice()), Some(b"data".as_slice())], DType::Binary(Nullability::NonNullable)))]
    #[case(VarBinViewArray::from_iter(vec![Some(b"test".as_slice()), None], DType::Binary(Nullability::Nullable)))]
    #[case(VarBinViewArray::from_iter(vec![Some("single")], DType::Utf8(Nullability::NonNullable)))]
    #[case(VarBinViewArray::from_iter(vec![Some("very long string that exceeds the inline size to test view functionality with multiple buffers")], DType::Utf8(Nullability::NonNullable)))]
    fn test_cast_varbinview_conformance(#[case] array: VarBinViewArray) {
        let array = array.into_array();
        test_cast_conformance(&array);
    }
}