// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright the Vortex contributors

4use std::sync::Arc;
5
6use arbitrary::{Arbitrary, Result, Unstructured};
7use vortex_error::VortexExpect;
8
9use crate::{
10    DECIMAL256_MAX_PRECISION, DECIMAL256_MAX_SCALE, DType, DecimalDType, FieldName, FieldNames,
11    Nullability, PType, StructFields,
12};
13
14impl<'a> Arbitrary<'a> for DType {
15    fn arbitrary(u: &mut Unstructured<'a>) -> Result<Self> {
16        random_dtype(u, 2)
17    }
18}
19
20impl<'a> Arbitrary<'a> for FieldName {
21    fn arbitrary(u: &mut Unstructured<'a>) -> Result<Self> {
22        let i: Arc<str> = Arbitrary::arbitrary(u)?;
23        Ok(Self::from(i))
24    }
25}
26
27fn random_dtype(u: &mut Unstructured<'_>, depth: u8) -> Result<DType> {
28    const BASE_TYPE_COUNT: i32 = 5;
29    const CONTAINER_TYPE_COUNT: i32 = 3;
30    let max_dtype_kind = if depth == 0 {
31        BASE_TYPE_COUNT
32    } else {
33        CONTAINER_TYPE_COUNT + BASE_TYPE_COUNT
34    };
35    Ok(match u.int_in_range(1..=max_dtype_kind)? {
36        // base types
37        1 => DType::Bool(u.arbitrary()?),
38        2 => DType::Primitive(u.arbitrary()?, u.arbitrary()?),
39        3 => DType::Decimal(u.arbitrary()?, u.arbitrary()?),
40        4 => DType::Utf8(u.arbitrary()?),
41        5 => DType::Binary(u.arbitrary()?),
42
43        // container types
44        6 => DType::Struct(random_struct_dtype(u, depth - 1)?, u.arbitrary()?),
45        7 => DType::List(Arc::new(random_dtype(u, depth - 1)?), u.arbitrary()?),
46        8 => DType::FixedSizeList(
47            Arc::new(random_dtype(u, depth - 1)?),
48            // We limit the list size to 3 rather (following random struct fields).
49            u.choose_index(3)?.try_into().vortex_expect("impossible"),
50            u.arbitrary()?,
51        ),
52        // Null,
53        // Extension(ExtDType, Nullability),
54        _ => unreachable!("Number out of range"),
55    })
56}
57
58impl<'a> Arbitrary<'a> for Nullability {
59    fn arbitrary(u: &mut Unstructured<'a>) -> Result<Self> {
60        Ok(if u.arbitrary()? {
61            Nullability::Nullable
62        } else {
63            Nullability::NonNullable
64        })
65    }
66}
67
68impl<'a> Arbitrary<'a> for PType {
69    fn arbitrary(u: &mut Unstructured<'a>) -> Result<Self> {
70        Ok(match u.int_in_range(0..=10)? {
71            0 => PType::U8,
72            1 => PType::U16,
73            2 => PType::U32,
74            3 => PType::U64,
75            4 => PType::I8,
76            5 => PType::I16,
77            6 => PType::I32,
78            7 => PType::I64,
79            8 => PType::F16,
80            9 => PType::F32,
81            10 => PType::F64,
82            _ => unreachable!("Number out of range"),
83        })
84    }
85}
86
87impl<'a> Arbitrary<'a> for DecimalDType {
88    #[allow(
89        clippy::unwrap_in_result,
90        clippy::expect_used,
91        clippy::cast_possible_truncation
92    )]
93    fn arbitrary(u: &mut Unstructured<'a>) -> Result<Self> {
94        // Get a random integer for the scale
95        let precision = u.int_in_range(1..=DECIMAL256_MAX_PRECISION)?;
96        let scale = u.int_in_range(-DECIMAL256_MAX_SCALE..=(precision as i8))?;
97        Ok(Self::new(precision, scale))
98    }
99}
100
101impl<'a> Arbitrary<'a> for StructFields {
102    fn arbitrary(u: &mut Unstructured<'a>) -> Result<Self> {
103        random_struct_dtype(u, 1)
104    }
105}
106
107fn random_struct_dtype(u: &mut Unstructured<'_>, depth: u8) -> Result<StructFields> {
108    let field_count = u.choose_index(3)?;
109    let names: FieldNames = (0..field_count)
110        .map(|_| FieldName::arbitrary(u))
111        .collect::<Result<FieldNames>>()?;
112    let dtypes = (0..names.len())
113        .map(|_| random_dtype(u, depth))
114        .collect::<Result<Vec<_>>>()?;
115    Ok(StructFields::new(names, dtypes))
116}