1use std::ops::Range;
4
5use arrow2::{
6 array::{Array, PrimitiveArray},
7 bitmap::Bitmap,
8 types::NativeType,
9};
10use sample_std::{
11 arbitrary, sampler_choice, valid_f32, valid_f64, Arbitrary, ArbitrarySampler, Random, Sample,
12 Shrunk, VecSampler,
13};
14
15use crate::{ArrowLenSampler, ArrowSampler, SetLen};
16
/// Sampler for [`PrimitiveArray`]s whose length is configured through [`SetLen`].
///
/// Values are drawn from `inner`; the validity attached to each generated
/// array is drawn from `validity`.
#[derive(Debug, Clone)]
pub struct PrimitiveArraySampler<PT, V> {
    /// Number of values generated for each array.
    len: usize,
    /// Sampler producing the individual values.
    inner: PT,
    /// Sampler producing the validity set on each generated array.
    validity: V,
}
23
24impl<PT, V: SetLen> SetLen for PrimitiveArraySampler<PT, V> {
25 fn set_len(&mut self, len: usize) {
26 self.len = len;
27 self.validity.set_len(len);
28 }
29}
30
31impl<PT, V> Sample for PrimitiveArraySampler<PT, V>
32where
33 PT: Sample,
34 PT::Output: NativeType,
35 V: Sample<Output = Option<Bitmap>> + SetLen,
36{
37 type Output = PrimitiveArray<PT::Output>;
38
39 fn generate(&mut self, g: &mut Random) -> Self::Output {
40 let vec = (0..self.len).map(|_| self.inner.generate(g)).collect();
41 let mut arr = PrimitiveArray::from_vec(vec);
42 arr.set_validity(self.validity.generate(g));
43 arr
44 }
45
46 fn shrink(&self, _: Self::Output) -> Box<dyn Iterator<Item = Self::Output>> {
47 Box::new(std::iter::empty())
48 }
49}
50
51impl<PT, V> PrimitiveArraySampler<PT, V>
52where
53 PT: Sample + 'static,
54 PT::Output: NativeType,
55 V: Sample<Output = Option<Bitmap>> + SetLen + 'static,
56{
57 fn boxed(self) -> ArrowLenSampler {
58 fn unbox<T: NativeType>(boxed: Box<dyn Array>) -> Option<PrimitiveArray<T>> {
59 boxed.as_any().downcast_ref::<PrimitiveArray<T>>().cloned()
60 }
61 Box::new(self.try_convert(PrimitiveArray::boxed, unbox))
62 }
63}
64
65pub fn primitive_len_sampler<PT, V>(inner: PT, validity: V) -> ArrowLenSampler
66where
67 PT: Sample + 'static,
68 PT::Output: NativeType,
69 V: Sample<Output = Option<Bitmap>> + SetLen + 'static,
70{
71 PrimitiveArraySampler {
72 len: 0,
73 inner,
74 validity,
75 }
76 .boxed()
77}
78
79pub fn arbitrary_len_sampler<T, V>(validity: V) -> ArrowLenSampler
80where
81 V: Sample<Output = Option<Bitmap>> + SetLen + 'static,
82 T: NativeType + Arbitrary,
83{
84 primitive_len_sampler(arbitrary::<T>(), validity)
85}
86
87pub fn valid_float_len_sampler<V>(valid: V) -> ArrowLenSampler
88where
89 V: Sample<Output = Option<Bitmap>> + SetLen + Clone + 'static,
90{
91 Box::new(sampler_choice(vec![
92 primitive_len_sampler(valid_f32(), valid.clone()),
94 primitive_len_sampler(valid_f64(), valid),
95 ]))
96}
97
98pub fn arbitrary_float_len_sampler<V>(valid: V) -> ArrowLenSampler
99where
100 V: Sample<Output = Option<Bitmap>> + SetLen + Clone + 'static,
101{
102 Box::new(sampler_choice(vec![
103 arbitrary_len_sampler::<f32, _>(valid.clone()),
105 arbitrary_len_sampler::<f32, _>(valid),
106 ]))
107}
108
109pub fn arbitrary_int_len_sampler<V>(valid: V) -> ArrowLenSampler
110where
111 V: Sample<Output = Option<Bitmap>> + SetLen + Clone + 'static,
112{
113 Box::new(sampler_choice(vec![
114 arbitrary_len_sampler::<i8, _>(valid.clone()),
115 arbitrary_len_sampler::<i16, _>(valid.clone()),
116 arbitrary_len_sampler::<i32, _>(valid.clone()),
117 arbitrary_len_sampler::<i64, _>(valid.clone()),
118 ]))
119}
120
121pub fn arbitrary_uint_len_sampler<V>(valid: V) -> ArrowLenSampler
124where
125 V: Sample<Output = Option<Bitmap>> + SetLen + Clone + 'static,
126{
127 Box::new(sampler_choice(vec![
128 arbitrary_len_sampler::<u8, _>(valid.clone()),
129 arbitrary_len_sampler::<u16, _>(valid.clone()),
130 arbitrary_len_sampler::<u32, _>(valid.clone()),
131 arbitrary_len_sampler::<u64, _>(valid.clone()),
132 ]))
133}
134
135pub fn valid_primitive_len<V>(valid: V) -> ArrowLenSampler
136where
137 V: Sample<Output = Option<Bitmap>> + SetLen + Clone + 'static,
138{
139 Box::new(sampler_choice([
140 valid_float_len_sampler(valid.clone()),
141 arbitrary_int_len_sampler(valid.clone()),
142 arbitrary_uint_len_sampler(valid.clone()),
143 ]))
144}
145
146pub fn arbitrary_primitive_len<V>(valid: V) -> ArrowLenSampler
147where
148 V: Sample<Output = Option<Bitmap>> + SetLen + Clone + 'static,
149{
150 Box::new(sampler_choice([
151 arbitrary_float_len_sampler(valid.clone()),
152 arbitrary_int_len_sampler(valid.clone()),
153 arbitrary_uint_len_sampler(valid.clone()),
154 ]))
155}
156
/// Sampler of [`PrimitiveArray`]s built from a vector sampler whose elements
/// are `Option`s (see the `Sample` impl, which requires
/// `PT: Sample<Output = Option<T>>`).
#[derive(Debug, Clone)]
pub struct ProtoNullablePrimitiveArray<PT> {
    /// Vector sampler: a `Range<usize>` drives the length, `PT` the elements.
    inner: VecSampler<Range<usize>, PT>,
}
161
162fn to_primitive<T>(vec: Vec<Option<T>>) -> PrimitiveArray<T>
163where
164 T: NativeType,
165{
166 PrimitiveArray::from_trusted_len_iter(vec.into_iter())
167}
168
169impl<PT, T> Sample for ProtoNullablePrimitiveArray<PT>
170where
171 PT: Sample<Output = Option<T>> + Clone + 'static,
172 T: NativeType + Arbitrary,
173{
174 type Output = PrimitiveArray<T>;
175
176 fn generate(&mut self, g: &mut Random) -> Self::Output {
177 to_primitive(self.inner.generate(g))
178 }
179
180 fn shrink(&self, v: Self::Output) -> Shrunk<Self::Output> {
181 let vec = v.iter().map(|el| el.cloned()).collect();
182 Box::new(self.inner.shrink(vec).map(to_primitive))
183 }
184}
185
/// Sampler of [`PrimitiveArray`]s built from a vector sampler of plain values
/// (see the `Sample` impl, which requires `PT: Sample<Output = T>`).
#[derive(Debug, Clone)]
pub struct ProtoPrimitiveArray<PT> {
    /// Vector sampler: a `Range<usize>` drives the length, `PT` the elements.
    inner: VecSampler<Range<usize>, PT>,
}
190
191impl<PT, T> Sample for ProtoPrimitiveArray<PT>
192where
193 PT: Sample<Output = T> + Clone + 'static,
194 T: NativeType,
195{
196 type Output = PrimitiveArray<T>;
197
198 fn generate(&mut self, g: &mut Random) -> Self::Output {
199 PrimitiveArray::from_trusted_len_values_iter(self.inner.generate(g).into_iter())
200 }
201
202 fn shrink(&self, v: Self::Output) -> Shrunk<Self::Output> {
203 let vec = v.values_iter().cloned().collect();
204 Box::new(
205 self.inner
206 .shrink(vec)
207 .map(IntoIterator::into_iter)
208 .map(PrimitiveArray::from_trusted_len_values_iter),
209 )
210 }
211}
212
213pub fn boxed_nullable<GT, T>(len: Range<usize>, el: GT) -> ArrowSampler
214where
215 GT: Sample<Output = Option<T>> + Send + Sync + Clone + 'static,
216 T: NativeType + Arbitrary,
217{
218 Box::new(ProtoBoxedNullablePrimitiveArray {
219 inner: ProtoNullablePrimitiveArray {
220 inner: VecSampler { length: len, el },
221 },
222 })
223}
224
225pub fn boxed<GT, T>(len: Range<usize>, el: GT) -> ArrowSampler
226where
227 GT: Sample<Output = T> + Send + Sync + Clone + 'static,
228 T: NativeType + Arbitrary,
229{
230 Box::new(
231 ProtoPrimitiveArray {
232 inner: VecSampler { length: len, el },
233 }
234 .try_convert(PrimitiveArray::boxed, |boxed| {
235 if boxed.validity().is_none() {
236 boxed.as_any().downcast_ref::<PrimitiveArray<T>>().cloned()
237 } else {
238 None
239 }
240 }),
241 )
242}
243
244#[derive(Clone)]
245pub struct ProtoBoxedNullablePrimitiveArray<PT> {
246 inner: ProtoNullablePrimitiveArray<PT>,
247}
248
249impl<GT, T> Sample for ProtoBoxedNullablePrimitiveArray<GT>
250where
251 GT: Sample<Output = Option<T>> + Clone + 'static,
252 T: NativeType + Arbitrary,
253{
254 type Output = Box<dyn Array>;
255
256 fn generate(&mut self, g: &mut Random) -> Self::Output {
257 self.inner.generate(g).boxed()
258 }
259
260 fn shrink(&self, v: Self::Output) -> Shrunk<Self::Output> {
261 Box::new(
262 v.as_any()
263 .downcast_ref::<PrimitiveArray<T>>()
264 .cloned()
265 .into_iter()
266 .flat_map(move |arr| self.inner.shrink(arr.clone()).map(|arr| arr.boxed())),
267 )
268 }
269}
270
/// Sampler adaptor that wraps the output of `inner` in an `Option`, using the
/// boolean sampler `null` to decide when to emit `None`.
#[derive(Clone)]
struct Nullable<SI, V> {
    /// Sampler for the wrapped value.
    inner: SI,
    /// Boolean sampler; a `true` draw produces `None`.
    null: V,
}
276
277impl<SI, V> Sample for Nullable<SI, V>
278where
279 SI: Sample,
280 V: Sample<Output = bool>,
281{
282 type Output = Option<SI::Output>;
283 fn generate(&mut self, g: &mut Random) -> Self::Output {
284 if self.null.generate(g) {
285 None
286 } else {
287 Some(self.inner.generate(g))
288 }
289 }
290
291 fn shrink(&self, v: Self::Output) -> Shrunk<Self::Output> {
292 if let Some(v) = v {
293 Box::new(std::iter::once(None).chain(self.inner.shrink(v).map(Some)))
294 } else {
295 Box::new(std::iter::empty())
296 }
297 }
298}
299
300pub fn boxed_primitive<T, V>(
301 el: ArbitrarySampler<T>,
302 len: Range<usize>,
303 null: Option<V>,
304) -> ArrowSampler
305where
306 T: Arbitrary + NativeType,
307 V: Sample<Output = bool> + Clone + Send + Sync + 'static,
308{
309 match null {
310 Some(null) => boxed_nullable(len.clone(), Nullable { inner: el, null }),
311 None => boxed(len.clone(), el),
312 }
313}
314
315pub fn arbitrary_boxed_primitive<T, V>(len: Range<usize>, null: Option<V>) -> ArrowSampler
316where
317 V: Sample<Output = bool> + Clone + Send + Sync + 'static,
318 T: NativeType + Arbitrary,
319{
320 boxed_primitive(arbitrary::<T>(), len, null)
321}
322
323pub fn valid_float_array<V>(len: Range<usize>, null: Option<V>) -> ArrowSampler
326where
327 V: Sample<Output = bool> + Clone + Send + Sync + 'static,
328{
329 Box::new(sampler_choice(vec![
330 boxed_primitive(valid_f32(), len.clone(), null.clone()),
332 boxed_primitive(valid_f64(), len, null),
333 ]))
334}
335
336pub fn arbitrary_float_array<V>(len: Range<usize>, null: Option<V>) -> ArrowSampler
337where
338 V: Sample<Output = bool> + Clone + Send + Sync + 'static,
339{
340 Box::new(sampler_choice(vec![
341 arbitrary_boxed_primitive::<f32, _>(len.clone(), null.clone()),
343 arbitrary_boxed_primitive::<f32, _>(len, null),
344 ]))
345}
346
347pub fn arbitrary_int_array<V>(len: Range<usize>, null: Option<V>) -> ArrowSampler
348where
349 V: Sample<Output = bool> + Clone + Send + Sync + 'static,
350{
351 Box::new(sampler_choice(vec![
352 arbitrary_boxed_primitive::<i8, _>(len.clone(), null.clone()),
353 arbitrary_boxed_primitive::<i16, _>(len.clone(), null.clone()),
354 arbitrary_boxed_primitive::<i32, _>(len.clone(), null.clone()),
355 arbitrary_boxed_primitive::<i64, _>(len.clone(), null.clone()),
356 ]))
357}
358
359pub fn arbitrary_uint_array<V>(len: Range<usize>, null: Option<V>) -> ArrowSampler
362where
363 V: Sample<Output = bool> + Clone + Send + Sync + 'static,
364{
365 Box::new(sampler_choice(vec![
366 arbitrary_boxed_primitive::<u8, _>(len.clone(), null.clone()),
367 arbitrary_boxed_primitive::<u16, _>(len.clone(), null.clone()),
368 arbitrary_boxed_primitive::<u32, _>(len.clone(), null.clone()),
369 arbitrary_boxed_primitive::<u64, _>(len.clone(), null.clone()),
370 ]))
371}
372
373pub fn valid_primitive<V>(len: Range<usize>, null: Option<V>) -> ArrowSampler
374where
375 V: Sample<Output = bool> + Clone + Send + Sync + 'static,
376{
377 Box::new(sampler_choice([
378 valid_float_array(len.clone(), null.clone()),
379 arbitrary_int_array(len.clone(), null.clone()),
380 arbitrary_uint_array(len.clone(), null.clone()),
381 ]))
382}
383
384pub fn arbitrary_primitive<V>(len: Range<usize>, null: Option<V>) -> ArrowSampler
385where
386 V: Sample<Output = bool> + Clone + Send + Sync + 'static,
387{
388 Box::new(sampler_choice([
389 arbitrary_float_array(len.clone(), null.clone()),
390 arbitrary_int_array(len.clone(), null.clone()),
391 arbitrary_uint_array(len.clone(), null.clone()),
392 ]))
393}
394
#[cfg(test)]
mod tests {
    use sample_std::Chance;

    use super::*;

    /// Smoke test for `valid_float_array` with a null sampler: generation
    /// must succeed, honour the requested length, and produce an array that
    /// compares equal to itself.
    #[test]
    fn gen_float() {
        // `gen` is a reserved keyword from the 2024 edition on, so the sampler
        // gets a descriptive name instead.
        let mut sampler = valid_float_array(50..51, Some(Chance(0.5)));
        let mut rng = Random::new();
        let arr = sampler.generate(&mut rng);

        // The half-open range `50..51` contains a single length.
        // NOTE(review): assumes a `Range` length sampler excludes its end,
        // like `std::ops::Range` itself; confirm against `sample_std`.
        assert_eq!(arr.len(), 50);

        // Kept from the original test. Presumably meaningful because the
        // "valid" float samplers avoid NaN, which would make an array
        // unequal to itself. TODO confirm.
        assert_eq!(arr, arr);
    }
}