Skip to main content

vortex_array/arrays/scalar_fn/vtable/
validity.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use vortex_error::VortexResult;
5
6use crate::ArrayRef;
7use crate::IntoArray;
8use crate::LEGACY_SESSION;
9use crate::VortexSessionExecute;
10use crate::array::ArrayView;
11use crate::array::ValidityVTable;
12use crate::arrays::scalar_fn::ScalarFnArrayExt;
13use crate::arrays::scalar_fn::vtable::ArrayExpr;
14use crate::arrays::scalar_fn::vtable::FakeEq;
15use crate::arrays::scalar_fn::vtable::ScalarFnVTable;
16use crate::expr::Expression;
17use crate::expr::lit;
18use crate::scalar_fn::ScalarFn;
19use crate::scalar_fn::VecExecutionArgs;
20use crate::scalar_fn::fns::literal::Literal;
21use crate::scalar_fn::fns::root::Root;
22use crate::validity::Validity;
23
24/// Execute an expression tree recursively.
25///
26/// This assumes all leaf expressions are either ArrayExpr (wrapping actual arrays) or Literals.
27fn execute_expr(expr: &Expression, row_count: usize) -> VortexResult<ArrayRef> {
28    let mut ctx = LEGACY_SESSION.create_execution_ctx();
29
30    // Handle Root expression - this should not happen in validity expressions
31    if expr.is::<Root>() {
32        vortex_error::vortex_bail!("Root expression cannot be executed in validity context");
33    }
34
35    // Handle Literal expression - create a constant array
36    if expr.is::<Literal>() {
37        let scalar = expr.as_::<Literal>();
38        return Ok(crate::arrays::ConstantArray::new(scalar.clone(), row_count).into_array());
39    }
40
41    // Recursively execute child expressions to get input arrays
42    let inputs: Vec<ArrayRef> = expr
43        .children()
44        .iter()
45        .map(|child| execute_expr(child, row_count))
46        .collect::<VortexResult<_>>()?;
47
48    let args = VecExecutionArgs::new(inputs, row_count);
49
50    Ok(expr.scalar_fn().execute(&args, &mut ctx)?.into_array())
51}
52
53impl ValidityVTable<ScalarFnVTable> for ScalarFnVTable {
54    fn validity(array: ArrayView<'_, ScalarFnVTable>) -> VortexResult<Validity> {
55        let inputs: Vec<_> = array
56            .iter_children()
57            .map(|child| {
58                if let Some(scalar) = child.as_constant() {
59                    return Ok(lit(scalar));
60                }
61                Expression::try_new(ScalarFn::new(ArrayExpr, FakeEq(child.clone())).erased(), [])
62            })
63            .collect::<VortexResult<_>>()?;
64
65        let expr = Expression::try_new(array.scalar_fn().clone(), inputs)?;
66        let validity_expr = array.scalar_fn().validity(&expr)?;
67
68        // Execute the validity expression. All leaves are ArrayExpr nodes.
69        Ok(Validity::Array(execute_expr(&validity_expr, array.len())?))
70    }
71}