vortex_array/arrays/constant/compute/
take.rs1use vortex_error::VortexResult;
5use vortex_mask::AllOr;
6
7use crate::Array;
8use crate::ArrayRef;
9use crate::IntoArray;
10use crate::arrays::ConstantArray;
11use crate::arrays::ConstantVTable;
12use crate::arrays::MaskedArray;
13use crate::arrays::TakeReduce;
14use crate::arrays::TakeReduceAdaptor;
15use crate::optimizer::rules::ParentRuleSet;
16use crate::scalar::Scalar;
17use crate::validity::Validity;
18
19impl TakeReduce for ConstantVTable {
20 fn take(array: &ConstantArray, indices: &dyn Array) -> VortexResult<Option<ArrayRef>> {
21 let result = match indices.validity_mask()?.bit_buffer() {
22 AllOr::All => {
23 let scalar = Scalar::try_new(
24 array
25 .scalar()
26 .dtype()
27 .union_nullability(indices.dtype().nullability()),
28 array.scalar().value().cloned(),
29 )?;
30 ConstantArray::new(scalar, indices.len()).into_array()
31 }
32 AllOr::None => ConstantArray::new(
33 Scalar::null(
34 array
35 .dtype()
36 .union_nullability(indices.dtype().nullability()),
37 ),
38 indices.len(),
39 )
40 .into_array(),
41 AllOr::Some(v) => {
42 let arr = ConstantArray::new(array.scalar().clone(), indices.len()).into_array();
43
44 if array.scalar().is_null() {
45 return Ok(Some(arr));
46 }
47
48 MaskedArray::try_new(arr, Validity::from(v.clone()))?.into_array()
49 }
50 };
51 Ok(Some(result))
52 }
53}
54
55impl ConstantVTable {
56 pub const TAKE_RULES: ParentRuleSet<Self> =
57 ParentRuleSet::new(&[ParentRuleSet::lift(&TakeReduceAdaptor::<Self>(Self))]);
58}
59
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_buffer::buffer;
    use vortex_dtype::Nullability;
    use vortex_mask::AllOr;

    use crate::Array;
    use crate::IntoArray;
    use crate::ToCanonical;
    use crate::arrays::ConstantArray;
    use crate::arrays::PrimitiveArray;
    use crate::assert_arrays_eq;
    use crate::compute::conformance::take::test_take_conformance;
    use crate::scalar::Scalar;
    use crate::validity::Validity;

    /// Partially valid indices on a non-null constant: the result is nullable and null exactly
    /// where the indices are null.
    #[test]
    fn take_nullable_indices() {
        let array = ConstantArray::new(42, 10).to_array();
        let taken = array
            .take(
                PrimitiveArray::new(
                    buffer![0, 5, 7],
                    Validity::from_iter(vec![false, true, false]),
                )
                .into_array(),
            )
            .unwrap();
        let valid_indices: &[usize] = &[1usize];
        assert_eq!(
            &array.dtype().with_nullability(Nullability::Nullable),
            taken.dtype()
        );
        assert_arrays_eq!(
            taken.to_primitive(),
            PrimitiveArray::new(
                buffer![42i32, 42, 42],
                Validity::from_iter([false, true, false])
            )
        );
        assert_eq!(
            taken.validity_mask().unwrap().indices(),
            AllOr::Some(valid_indices)
        );
    }

    /// All-valid (but nullable) indices: the result dtype is widened to nullable, and every
    /// position is valid.
    #[test]
    fn take_all_valid_indices() {
        let array = ConstantArray::new(42, 10).to_array();
        let taken = array
            .take(PrimitiveArray::new(buffer![0, 5, 7], Validity::AllValid).into_array())
            .unwrap();
        assert_eq!(
            &array.dtype().with_nullability(Nullability::Nullable),
            taken.dtype()
        );
        assert_arrays_eq!(
            taken.to_primitive(),
            PrimitiveArray::new(buffer![42i32, 42, 42], Validity::AllValid)
        );
        assert_eq!(taken.validity_mask().unwrap().indices(), AllOr::All);
    }

    /// All-null indices (intended to exercise the `AllOr::None` path): every output position
    /// is null, and the dtype is widened to nullable.
    #[test]
    fn take_all_null_indices() {
        let array = ConstantArray::new(42, 10).to_array();
        let taken = array
            .take(PrimitiveArray::new(buffer![0, 5, 7], Validity::AllInvalid).into_array())
            .unwrap();
        assert_eq!(
            &array.dtype().with_nullability(Nullability::Nullable),
            taken.dtype()
        );
        assert_eq!(taken.len(), 3);
        assert_eq!(taken.validity_mask().unwrap().indices(), AllOr::None);
    }

    /// A null constant taken with partially valid indices (intended to exercise the
    /// null-scalar shortcut of the `AllOr::Some` path): the result stays entirely null and
    /// keeps the constant's own dtype.
    #[test]
    fn take_null_constant_with_nullable_indices() {
        let array = ConstantArray::new(Scalar::null_native::<i32>(), 10).to_array();
        let taken = array
            .take(
                PrimitiveArray::new(
                    buffer![0, 5, 7],
                    Validity::from_iter([false, true, false]),
                )
                .into_array(),
            )
            .unwrap();
        assert_eq!(array.dtype(), taken.dtype());
        assert_eq!(taken.len(), 3);
        assert_eq!(taken.validity_mask().unwrap().indices(), AllOr::None);
    }

    #[rstest]
    #[case(ConstantArray::new(42i32, 5))]
    #[case(ConstantArray::new(std::f64::consts::PI, 10))]
    #[case(ConstantArray::new(Scalar::from("hello"), 3))]
    #[case(ConstantArray::new(Scalar::null_native::<i64>(), 5))]
    #[case(ConstantArray::new(true, 1))]
    fn test_take_constant_conformance(#[case] array: ConstantArray) {
        test_take_conformance(array.as_ref());
    }
}