vortex_expr/
arbitrary.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::cmp::max;
5
6use arbitrary::{Result as AResult, Unstructured};
7use vortex_dtype::{DType, FieldName};
8use vortex_scalar::arbitrary::random_scalar;
9
10use crate::{BinaryExpr, ExprRef, Operator, and_collect, col, lit, pack};
11
12pub fn projection_expr(u: &mut Unstructured<'_>, dtype: &DType) -> AResult<Option<ExprRef>> {
13    let Some(struct_dtype) = dtype.as_struct_fields_opt() else {
14        return Ok(None);
15    };
16
17    let column_count = u.int_in_range::<usize>(0..=max(struct_dtype.nfields(), 10))?;
18
19    let cols = (0..column_count)
20        .map(|_| {
21            let get_item = u.choose_iter(struct_dtype.names().iter())?;
22            Ok((get_item.clone(), col(get_item.clone())))
23        })
24        .collect::<AResult<Vec<_>>>()?;
25
26    Ok(Some(pack(cols, u.arbitrary()?)))
27}
28
29pub fn filter_expr(u: &mut Unstructured<'_>, dtype: &DType) -> AResult<Option<ExprRef>> {
30    let Some(struct_dtype) = dtype.as_struct_fields_opt() else {
31        return Ok(None);
32    };
33
34    let filter_count = u.int_in_range::<usize>(0..=max(struct_dtype.nfields(), 10))?;
35
36    let filters = (0..filter_count)
37        .filter_map(|_| {
38            match u.choose_iter(struct_dtype.names().iter().zip(struct_dtype.fields())) {
39                Ok((col, dtype)) => {
40                    if dtype.is_struct() || dtype.is_list() {
41                        None
42                    } else {
43                        Some(random_comparison(u, col, &dtype))
44                    }
45                }
46                Err(e) => Some(Err(e)),
47            }
48        })
49        .collect::<AResult<Vec<_>>>()?;
50
51    Ok(and_collect(filters))
52}
53
54fn random_comparison(
55    u: &mut Unstructured<'_>,
56    name: &FieldName,
57    dtype: &DType,
58) -> AResult<ExprRef> {
59    let scalar = random_scalar(u, dtype)?;
60    Ok(BinaryExpr::new_expr(
61        col(name.clone()),
62        arbitrary_comparison_operator(u)?,
63        lit(scalar),
64    ))
65}
66
67fn arbitrary_comparison_operator(u: &mut Unstructured<'_>) -> AResult<Operator> {
68    Ok(match u.int_in_range(0..=5)? {
69        0 => Operator::Eq,
70        1 => Operator::NotEq,
71        2 => Operator::Gt,
72        3 => Operator::Gte,
73        4 => Operator::Lt,
74        5 => Operator::Lte,
75        _ => unreachable!("range 0..=5"),
76    })
77}