vortex_array/expr/
arbitrary.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::cmp::max;
5
6use arbitrary::Result as AResult;
7use arbitrary::Unstructured;
8use vortex_dtype::DType;
9use vortex_dtype::FieldName;
10use vortex_scalar::arbitrary::random_scalar;
11
12use crate::expr::Binary;
13use crate::expr::Expression;
14use crate::expr::Operator;
15use crate::expr::VTableExt;
16use crate::expr::and_collect;
17use crate::expr::col;
18use crate::expr::lit;
19use crate::expr::pack;
20
21pub fn projection_expr(u: &mut Unstructured<'_>, dtype: &DType) -> AResult<Option<Expression>> {
22    let Some(struct_dtype) = dtype.as_struct_fields_opt() else {
23        return Ok(None);
24    };
25
26    let column_count = u.int_in_range::<usize>(0..=max(struct_dtype.nfields(), 10))?;
27
28    let cols = (0..column_count)
29        .map(|_| {
30            let get_item = u.choose_iter(struct_dtype.names().iter())?;
31            Ok((get_item.clone(), col(get_item.clone())))
32        })
33        .collect::<AResult<Vec<_>>>()?;
34
35    Ok(Some(pack(cols, u.arbitrary()?)))
36}
37
38pub fn filter_expr(u: &mut Unstructured<'_>, dtype: &DType) -> AResult<Option<Expression>> {
39    let Some(struct_dtype) = dtype.as_struct_fields_opt() else {
40        return Ok(None);
41    };
42
43    let filter_count = u.int_in_range::<usize>(0..=max(struct_dtype.nfields(), 10))?;
44
45    let filters = (0..filter_count)
46        .map(|_| {
47            let (col, dtype) =
48                u.choose_iter(struct_dtype.names().iter().zip(struct_dtype.fields()))?;
49            random_comparison(u, col, &dtype)
50        })
51        .collect::<AResult<Vec<_>>>()?;
52
53    Ok(and_collect(filters))
54}
55
56fn random_comparison(
57    u: &mut Unstructured<'_>,
58    name: &FieldName,
59    dtype: &DType,
60) -> AResult<Expression> {
61    let scalar = random_scalar(u, dtype)?;
62    Ok(Binary.new_expr(
63        arbitrary_comparison_operator(u)?,
64        [col(name.clone()), lit(scalar)],
65    ))
66}
67
68fn arbitrary_comparison_operator(u: &mut Unstructured<'_>) -> AResult<Operator> {
69    Ok(match u.int_in_range(0..=5)? {
70        0 => Operator::Eq,
71        1 => Operator::NotEq,
72        2 => Operator::Gt,
73        3 => Operator::Gte,
74        4 => Operator::Lt,
75        5 => Operator::Lte,
76        _ => unreachable!("range 0..=5"),
77    })
78}