// sample_arrow2/list.rs

//! Samplers for generating an arrow [`ListArray`].
2
3use std::ops::Range;
4
5use crate::{array::ArraySampler, generate_validity, SampleLen, SetLen};
6use arrow2::{
7    array::{Array, ListArray},
8    bitmap::Bitmap,
9    datatypes::{DataType, Field},
10    offset::OffsetsBuffer,
11};
12use sample_std::{Random, Sample};
13
/// Sampler for a [`ListArray`] with a fixed number of list slots.
///
/// Each field is itself a sampler; the `Sample` impl for this type draws
/// every slot's child count from `count`, the child values from `inner`, the
/// child field name from `inner_name`, and the top-level validity from
/// `validity`.
pub struct ListWithLen<V, C, A, N> {
    /// Number of list slots in every generated array.
    pub len: usize,
    /// Sampler for the (optional) top-level validity bitmap.
    pub validity: V,
    /// Sampler for the number of child elements in each slot.
    pub count: C,

    /// Sampler for the child values array; resized to the total child count
    /// before each generation.
    pub inner: A,
    /// Sampler for the name of the list's child field.
    pub inner_name: N,
}
22
23impl<V: SetLen, C, A, N> SetLen for ListWithLen<V, C, A, N> {
24    fn set_len(&mut self, len: usize) {
25        self.len = len;
26        self.validity.set_len(len);
27    }
28}
29
30impl<V, C, A, N> Sample for ListWithLen<V, C, A, N>
31where
32    V: Sample<Output = Option<Bitmap>> + SetLen,
33    C: Sample<Output = i32>,
34    A: Sample<Output = Box<dyn Array>> + SetLen,
35    N: Sample<Output = String>,
36{
37    type Output = Box<dyn Array>;
38
39    fn generate(&mut self, g: &mut Random) -> Self::Output {
40        let mut offsets = vec![0];
41        let mut inner_len: i32 = 0;
42        for _ in 0..self.len {
43            let count = self.count.generate(g);
44            assert!(count >= 0);
45            inner_len += count;
46            offsets.push(inner_len);
47        }
48
49        self.inner.set_len(inner_len as usize);
50        let values = self.inner.generate(g);
51        let is_nullable = values.validity().is_some();
52        let inner_name = self.inner_name.generate(g);
53        let field = Field::new(inner_name, values.data_type().clone(), is_nullable);
54        let data_type = DataType::List(Box::new(field));
55
56        // SAFETY: see loop above. starts at zero, asserts all increments are positive.
57        let offsets = OffsetsBuffer::try_from(offsets).unwrap();
58        let validity = self.validity.generate(g);
59        ListArray::new(data_type, offsets, values, validity).boxed()
60    }
61
62    fn shrink(&self, _: Self::Output) -> Box<dyn Iterator<Item = Self::Output>> {
63        Box::new(std::iter::empty())
64    }
65}
66
67impl<V, C, A, N> SampleLen for ListWithLen<V, C, A, N>
68where
69    V: Sample<Output = Option<Bitmap>> + SetLen,
70    C: Sample<Output = i32>,
71    A: Sample<Output = Box<dyn Array>> + SetLen,
72    N: Sample<Output = String>,
73{
74}
75
76pub struct ListSampler<V> {
77    pub data_type: DataType,
78    pub null: Option<V>,
79    pub len: Range<usize>,
80    pub inner: ArraySampler,
81}
82
83impl<V> Sample for ListSampler<V>
84where
85    V: Sample<Output = bool> + Send + Sync + 'static,
86{
87    type Output = Box<dyn Array>;
88
89    fn generate(&mut self, g: &mut Random) -> Self::Output {
90        let values = self.inner.generate(g);
91        let len = self.len.generate(g);
92        let mut ix = 0;
93        let mut offsets = vec![0];
94
95        for outer_ix in 0..len {
96            if outer_ix + 1 != len {
97                let remaining = values.len() - ix;
98                let fair = std::cmp::max(2, remaining / (len - outer_ix));
99                let upper = std::cmp::min(values.len() - ix, fair);
100                let count = g.gen_range(0..=upper);
101                ix += count;
102                offsets.push(ix as i32);
103            } else {
104                offsets.push(values.len() as i32);
105            }
106        }
107
108        let validity = generate_validity(&mut self.null, g, len);
109
110        ListArray::new(
111            self.data_type.clone(),
112            OffsetsBuffer::try_from(offsets).unwrap(),
113            values,
114            validity,
115        )
116        .boxed()
117    }
118}