sample_arrow2/
primitive.rs

1//! Samplers for generating an arrow [`PrimitiveArray`].
2
3use std::ops::Range;
4
5use arrow2::{
6    array::{Array, PrimitiveArray},
7    bitmap::Bitmap,
8    types::NativeType,
9};
10use sample_std::{
11    arbitrary, sampler_choice, valid_f32, valid_f64, Arbitrary, ArbitrarySampler, Random, Sample,
12    Shrunk, VecSampler,
13};
14
15use crate::{ArrowLenSampler, ArrowSampler, SetLen};
16
/// Sampler that produces a [`PrimitiveArray`] with a configurable length.
///
/// One element value is drawn from `inner` per slot, and the (optional)
/// validity bitmap is drawn from `validity`. The length is changed through
/// the [`SetLen`] implementation.
#[derive(Debug, Clone)]
pub struct PrimitiveArraySampler<PT, V> {
    /// Number of elements each generated array will contain.
    len: usize,
    /// Sampler for the individual element values.
    inner: PT,
    /// Sampler for the optional validity bitmap.
    validity: V,
}
23
24impl<PT, V: SetLen> SetLen for PrimitiveArraySampler<PT, V> {
25    fn set_len(&mut self, len: usize) {
26        self.len = len;
27        self.validity.set_len(len);
28    }
29}
30
31impl<PT, V> Sample for PrimitiveArraySampler<PT, V>
32where
33    PT: Sample,
34    PT::Output: NativeType,
35    V: Sample<Output = Option<Bitmap>> + SetLen,
36{
37    type Output = PrimitiveArray<PT::Output>;
38
39    fn generate(&mut self, g: &mut Random) -> Self::Output {
40        let vec = (0..self.len).map(|_| self.inner.generate(g)).collect();
41        let mut arr = PrimitiveArray::from_vec(vec);
42        arr.set_validity(self.validity.generate(g));
43        arr
44    }
45
46    fn shrink(&self, _: Self::Output) -> Box<dyn Iterator<Item = Self::Output>> {
47        Box::new(std::iter::empty())
48    }
49}
50
51impl<PT, V> PrimitiveArraySampler<PT, V>
52where
53    PT: Sample + 'static,
54    PT::Output: NativeType,
55    V: Sample<Output = Option<Bitmap>> + SetLen + 'static,
56{
57    fn boxed(self) -> ArrowLenSampler {
58        fn unbox<T: NativeType>(boxed: Box<dyn Array>) -> Option<PrimitiveArray<T>> {
59            boxed.as_any().downcast_ref::<PrimitiveArray<T>>().cloned()
60        }
61        Box::new(self.try_convert(PrimitiveArray::boxed, unbox))
62    }
63}
64
65pub fn primitive_len_sampler<PT, V>(inner: PT, validity: V) -> ArrowLenSampler
66where
67    PT: Sample + 'static,
68    PT::Output: NativeType,
69    V: Sample<Output = Option<Bitmap>> + SetLen + 'static,
70{
71    PrimitiveArraySampler {
72        len: 0,
73        inner,
74        validity,
75    }
76    .boxed()
77}
78
79pub fn arbitrary_len_sampler<T, V>(validity: V) -> ArrowLenSampler
80where
81    V: Sample<Output = Option<Bitmap>> + SetLen + 'static,
82    T: NativeType + Arbitrary,
83{
84    primitive_len_sampler(arbitrary::<T>(), validity)
85}
86
87pub fn valid_float_len_sampler<V>(valid: V) -> ArrowLenSampler
88where
89    V: Sample<Output = Option<Bitmap>> + SetLen + Clone + 'static,
90{
91    Box::new(sampler_choice(vec![
92        // todo: f16
93        primitive_len_sampler(valid_f32(), valid.clone()),
94        primitive_len_sampler(valid_f64(), valid),
95    ]))
96}
97
98pub fn arbitrary_float_len_sampler<V>(valid: V) -> ArrowLenSampler
99where
100    V: Sample<Output = Option<Bitmap>> + SetLen + Clone + 'static,
101{
102    Box::new(sampler_choice(vec![
103        // todo: f16
104        arbitrary_len_sampler::<f32, _>(valid.clone()),
105        arbitrary_len_sampler::<f32, _>(valid),
106    ]))
107}
108
109pub fn arbitrary_int_len_sampler<V>(valid: V) -> ArrowLenSampler
110where
111    V: Sample<Output = Option<Bitmap>> + SetLen + Clone + 'static,
112{
113    Box::new(sampler_choice(vec![
114        arbitrary_len_sampler::<i8, _>(valid.clone()),
115        arbitrary_len_sampler::<i16, _>(valid.clone()),
116        arbitrary_len_sampler::<i32, _>(valid.clone()),
117        arbitrary_len_sampler::<i64, _>(valid.clone()),
118    ]))
119}
120
121// todo: arbitrary_monthsdaysnano_array
122
123pub fn arbitrary_uint_len_sampler<V>(valid: V) -> ArrowLenSampler
124where
125    V: Sample<Output = Option<Bitmap>> + SetLen + Clone + 'static,
126{
127    Box::new(sampler_choice(vec![
128        arbitrary_len_sampler::<u8, _>(valid.clone()),
129        arbitrary_len_sampler::<u16, _>(valid.clone()),
130        arbitrary_len_sampler::<u32, _>(valid.clone()),
131        arbitrary_len_sampler::<u64, _>(valid.clone()),
132    ]))
133}
134
135pub fn valid_primitive_len<V>(valid: V) -> ArrowLenSampler
136where
137    V: Sample<Output = Option<Bitmap>> + SetLen + Clone + 'static,
138{
139    Box::new(sampler_choice([
140        valid_float_len_sampler(valid.clone()),
141        arbitrary_int_len_sampler(valid.clone()),
142        arbitrary_uint_len_sampler(valid.clone()),
143    ]))
144}
145
146pub fn arbitrary_primitive_len<V>(valid: V) -> ArrowLenSampler
147where
148    V: Sample<Output = Option<Bitmap>> + SetLen + Clone + 'static,
149{
150    Box::new(sampler_choice([
151        arbitrary_float_len_sampler(valid.clone()),
152        arbitrary_int_len_sampler(valid.clone()),
153        arbitrary_uint_len_sampler(valid.clone()),
154    ]))
155}
156
157#[derive(Debug, Clone)]
158pub struct ProtoNullablePrimitiveArray<PT> {
159    inner: VecSampler<Range<usize>, PT>,
160}
161
162fn to_primitive<T>(vec: Vec<Option<T>>) -> PrimitiveArray<T>
163where
164    T: NativeType,
165{
166    PrimitiveArray::from_trusted_len_iter(vec.into_iter())
167}
168
169impl<PT, T> Sample for ProtoNullablePrimitiveArray<PT>
170where
171    PT: Sample<Output = Option<T>> + Clone + 'static,
172    T: NativeType + Arbitrary,
173{
174    type Output = PrimitiveArray<T>;
175
176    fn generate(&mut self, g: &mut Random) -> Self::Output {
177        to_primitive(self.inner.generate(g))
178    }
179
180    fn shrink(&self, v: Self::Output) -> Shrunk<Self::Output> {
181        let vec = v.iter().map(|el| el.cloned()).collect();
182        Box::new(self.inner.shrink(vec).map(to_primitive))
183    }
184}
185
186#[derive(Debug, Clone)]
187pub struct ProtoPrimitiveArray<PT> {
188    inner: VecSampler<Range<usize>, PT>,
189}
190
191impl<PT, T> Sample for ProtoPrimitiveArray<PT>
192where
193    PT: Sample<Output = T> + Clone + 'static,
194    T: NativeType,
195{
196    type Output = PrimitiveArray<T>;
197
198    fn generate(&mut self, g: &mut Random) -> Self::Output {
199        PrimitiveArray::from_trusted_len_values_iter(self.inner.generate(g).into_iter())
200    }
201
202    fn shrink(&self, v: Self::Output) -> Shrunk<Self::Output> {
203        let vec = v.values_iter().cloned().collect();
204        Box::new(
205            self.inner
206                .shrink(vec)
207                .map(IntoIterator::into_iter)
208                .map(PrimitiveArray::from_trusted_len_values_iter),
209        )
210    }
211}
212
213pub fn boxed_nullable<GT, T>(len: Range<usize>, el: GT) -> ArrowSampler
214where
215    GT: Sample<Output = Option<T>> + Send + Sync + Clone + 'static,
216    T: NativeType + Arbitrary,
217{
218    Box::new(ProtoBoxedNullablePrimitiveArray {
219        inner: ProtoNullablePrimitiveArray {
220            inner: VecSampler { length: len, el },
221        },
222    })
223}
224
225pub fn boxed<GT, T>(len: Range<usize>, el: GT) -> ArrowSampler
226where
227    GT: Sample<Output = T> + Send + Sync + Clone + 'static,
228    T: NativeType + Arbitrary,
229{
230    Box::new(
231        ProtoPrimitiveArray {
232            inner: VecSampler { length: len, el },
233        }
234        .try_convert(PrimitiveArray::boxed, |boxed| {
235            if boxed.validity().is_none() {
236                boxed.as_any().downcast_ref::<PrimitiveArray<T>>().cloned()
237            } else {
238                None
239            }
240        }),
241    )
242}
243
244#[derive(Clone)]
245pub struct ProtoBoxedNullablePrimitiveArray<PT> {
246    inner: ProtoNullablePrimitiveArray<PT>,
247}
248
249impl<GT, T> Sample for ProtoBoxedNullablePrimitiveArray<GT>
250where
251    GT: Sample<Output = Option<T>> + Clone + 'static,
252    T: NativeType + Arbitrary,
253{
254    type Output = Box<dyn Array>;
255
256    fn generate(&mut self, g: &mut Random) -> Self::Output {
257        self.inner.generate(g).boxed()
258    }
259
260    fn shrink(&self, v: Self::Output) -> Shrunk<Self::Output> {
261        Box::new(
262            v.as_any()
263                .downcast_ref::<PrimitiveArray<T>>()
264                .cloned()
265                .into_iter()
266                .flat_map(move |arr| self.inner.shrink(arr.clone()).map(|arr| arr.boxed())),
267        )
268    }
269}
270
/// Element sampler wrapper that yields `None` whenever the `null` sampler
/// returns `true`, and `Some(value)` from `inner` otherwise.
#[derive(Clone)]
struct Nullable<SI, V> {
    /// Sampler for the non-null element values.
    inner: SI,
    /// Boolean sampler deciding whether an element is null.
    null: V,
}
276
277impl<SI, V> Sample for Nullable<SI, V>
278where
279    SI: Sample,
280    V: Sample<Output = bool>,
281{
282    type Output = Option<SI::Output>;
283    fn generate(&mut self, g: &mut Random) -> Self::Output {
284        if self.null.generate(g) {
285            None
286        } else {
287            Some(self.inner.generate(g))
288        }
289    }
290
291    fn shrink(&self, v: Self::Output) -> Shrunk<Self::Output> {
292        if let Some(v) = v {
293            Box::new(std::iter::once(None).chain(self.inner.shrink(v).map(Some)))
294        } else {
295            Box::new(std::iter::empty())
296        }
297    }
298}
299
300pub fn boxed_primitive<T, V>(
301    el: ArbitrarySampler<T>,
302    len: Range<usize>,
303    null: Option<V>,
304) -> ArrowSampler
305where
306    T: Arbitrary + NativeType,
307    V: Sample<Output = bool> + Clone + Send + Sync + 'static,
308{
309    match null {
310        Some(null) => boxed_nullable(len.clone(), Nullable { inner: el, null }),
311        None => boxed(len.clone(), el),
312    }
313}
314
315pub fn arbitrary_boxed_primitive<T, V>(len: Range<usize>, null: Option<V>) -> ArrowSampler
316where
317    V: Sample<Output = bool> + Clone + Send + Sync + 'static,
318    T: NativeType + Arbitrary,
319{
320    boxed_primitive(arbitrary::<T>(), len, null)
321}
322
323// todo: arbitrary_daysms_array
324
325pub fn valid_float_array<V>(len: Range<usize>, null: Option<V>) -> ArrowSampler
326where
327    V: Sample<Output = bool> + Clone + Send + Sync + 'static,
328{
329    Box::new(sampler_choice(vec![
330        // todo: f16
331        boxed_primitive(valid_f32(), len.clone(), null.clone()),
332        boxed_primitive(valid_f64(), len, null),
333    ]))
334}
335
336pub fn arbitrary_float_array<V>(len: Range<usize>, null: Option<V>) -> ArrowSampler
337where
338    V: Sample<Output = bool> + Clone + Send + Sync + 'static,
339{
340    Box::new(sampler_choice(vec![
341        // todo: f16
342        arbitrary_boxed_primitive::<f32, _>(len.clone(), null.clone()),
343        arbitrary_boxed_primitive::<f32, _>(len, null),
344    ]))
345}
346
347pub fn arbitrary_int_array<V>(len: Range<usize>, null: Option<V>) -> ArrowSampler
348where
349    V: Sample<Output = bool> + Clone + Send + Sync + 'static,
350{
351    Box::new(sampler_choice(vec![
352        arbitrary_boxed_primitive::<i8, _>(len.clone(), null.clone()),
353        arbitrary_boxed_primitive::<i16, _>(len.clone(), null.clone()),
354        arbitrary_boxed_primitive::<i32, _>(len.clone(), null.clone()),
355        arbitrary_boxed_primitive::<i64, _>(len.clone(), null.clone()),
356    ]))
357}
358
359// todo: arbitrary_monthsdaysnano_array
360
361pub fn arbitrary_uint_array<V>(len: Range<usize>, null: Option<V>) -> ArrowSampler
362where
363    V: Sample<Output = bool> + Clone + Send + Sync + 'static,
364{
365    Box::new(sampler_choice(vec![
366        arbitrary_boxed_primitive::<u8, _>(len.clone(), null.clone()),
367        arbitrary_boxed_primitive::<u16, _>(len.clone(), null.clone()),
368        arbitrary_boxed_primitive::<u32, _>(len.clone(), null.clone()),
369        arbitrary_boxed_primitive::<u64, _>(len.clone(), null.clone()),
370    ]))
371}
372
373pub fn valid_primitive<V>(len: Range<usize>, null: Option<V>) -> ArrowSampler
374where
375    V: Sample<Output = bool> + Clone + Send + Sync + 'static,
376{
377    Box::new(sampler_choice([
378        valid_float_array(len.clone(), null.clone()),
379        arbitrary_int_array(len.clone(), null.clone()),
380        arbitrary_uint_array(len.clone(), null.clone()),
381    ]))
382}
383
384pub fn arbitrary_primitive<V>(len: Range<usize>, null: Option<V>) -> ArrowSampler
385where
386    V: Sample<Output = bool> + Clone + Send + Sync + 'static,
387{
388    Box::new(sampler_choice([
389        arbitrary_float_array(len.clone(), null.clone()),
390        arbitrary_int_array(len.clone(), null.clone()),
391        arbitrary_uint_array(len.clone(), null.clone()),
392    ]))
393}
394
#[cfg(test)]
mod tests {
    use arrow2::datatypes::DataType;
    use sample_std::Chance;

    use super::*;

    /// Generates one nullable float array and checks its basic shape.
    #[test]
    fn gen_float() {
        // `gen` is avoided as a name: it is reserved in the 2024 edition.
        let mut sampler = valid_float_array(50..51, Some(Chance(0.5)));
        let mut r = Random::new();
        let arr = sampler.generate(&mut r);
        // The half-open range `50..51` admits a single length.
        assert_eq!(arr.len(), 50);
        assert!(matches!(
            arr.data_type(),
            DataType::Float32 | DataType::Float64
        ));
        // Sanity check: comparing the array with itself succeeds.
        assert_eq!(arr, arr);
    }
}