vortex_array/arrays/constant/compute/
take.rs1use vortex_error::VortexResult;
5use vortex_mask::AllOr;
6
7use crate::ArrayRef;
8use crate::IntoArray;
9use crate::LEGACY_SESSION;
10use crate::VortexSessionExecute;
11use crate::array::ArrayView;
12use crate::arrays::Constant;
13use crate::arrays::ConstantArray;
14use crate::arrays::MaskedArray;
15use crate::arrays::dict::TakeReduce;
16use crate::arrays::dict::TakeReduceAdaptor;
17use crate::optimizer::rules::ParentRuleSet;
18use crate::scalar::Scalar;
19use crate::validity::Validity;
20
21impl TakeReduce for Constant {
22 fn take(array: ArrayView<'_, Constant>, indices: &ArrayRef) -> VortexResult<Option<ArrayRef>> {
23 let mut ctx = LEGACY_SESSION.create_execution_ctx();
24 let result = match indices
25 .validity()?
26 .to_mask(indices.len(), &mut ctx)?
27 .bit_buffer()
28 {
29 AllOr::All => {
30 let scalar = Scalar::try_new(
31 array
32 .scalar()
33 .dtype()
34 .union_nullability(indices.dtype().nullability()),
35 array.scalar().value().cloned(),
36 )?;
37 ConstantArray::new(scalar, indices.len()).into_array()
38 }
39 AllOr::None => ConstantArray::new(
40 Scalar::null(
41 array
42 .dtype()
43 .union_nullability(indices.dtype().nullability()),
44 ),
45 indices.len(),
46 )
47 .into_array(),
48 AllOr::Some(v) => {
49 let arr = ConstantArray::new(array.scalar().clone(), indices.len()).into_array();
50
51 if array.scalar().is_null() {
52 return Ok(Some(arr));
53 }
54
55 MaskedArray::try_new(arr, Validity::from(v.clone()))?.into_array()
56 }
57 };
58 Ok(Some(result))
59 }
60}
61
impl Constant {
    /// Rule set that lifts this type's [`TakeReduce`] implementation (the
    /// `take` above, via [`TakeReduceAdaptor`]) into a [`ParentRuleSet`].
    ///
    /// NOTE(review): presumably registered with the optimizer so that a take
    /// whose child is a constant array is reduced without materializing it —
    /// confirm at the registration site.
    pub const TAKE_RULES: ParentRuleSet<Self> =
        ParentRuleSet::new(&[ParentRuleSet::lift(&TakeReduceAdaptor::<Self>(Self))]);
}
66
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_buffer::buffer;
    use vortex_mask::AllOr;

    use crate::IntoArray;
    use crate::LEGACY_SESSION;
    use crate::ToCanonical;
    use crate::VortexSessionExecute;
    use crate::arrays::ConstantArray;
    use crate::arrays::PrimitiveArray;
    use crate::assert_arrays_eq;
    use crate::compute::conformance::take::test_take_conformance;
    use crate::dtype::Nullability;
    use crate::scalar::Scalar;
    use crate::validity::Validity;

    /// Mixed-validity indices: the result is nullable and null exactly where
    /// the indices are null.
    #[test]
    fn take_nullable_indices() {
        let array = ConstantArray::new(42, 10).into_array();
        let taken = array
            .take(
                PrimitiveArray::new(
                    buffer![0, 5, 7],
                    Validity::from_iter(vec![false, true, false]),
                )
                .into_array(),
            )
            .unwrap();
        let valid_indices: &[usize] = &[1usize];
        assert_eq!(
            &array.dtype().with_nullability(Nullability::Nullable),
            taken.dtype()
        );
        assert_arrays_eq!(
            taken.to_primitive(),
            PrimitiveArray::new(
                buffer![42i32, 42, 42],
                Validity::from_iter([false, true, false])
            )
        );
        assert_eq!(
            taken
                .validity()
                .unwrap()
                .to_mask(taken.len(), &mut LEGACY_SESSION.create_execution_ctx())
                .unwrap()
                .indices(),
            AllOr::Some(valid_indices)
        );
    }

    /// Nullable-but-all-valid indices: the result is nullable yet fully valid.
    #[test]
    fn take_all_valid_indices() {
        let array = ConstantArray::new(42, 10).into_array();
        let taken = array
            .take(PrimitiveArray::new(buffer![0, 5, 7], Validity::AllValid).into_array())
            .unwrap();
        assert_eq!(
            &array.dtype().with_nullability(Nullability::Nullable),
            taken.dtype()
        );
        assert_arrays_eq!(
            taken.to_primitive(),
            PrimitiveArray::new(buffer![42i32, 42, 42], Validity::AllValid)
        );
        assert_eq!(
            taken
                .validity()
                .unwrap()
                .to_mask(taken.len(), &mut LEGACY_SESSION.create_execution_ctx())
                .unwrap()
                .indices(),
            AllOr::All
        );
    }

    /// All-null indices (the `AllOr::None` branch): the result is a nullable,
    /// entirely-null array of the constant's type with one slot per index.
    #[test]
    fn take_all_null_indices() {
        let array = ConstantArray::new(42, 10).into_array();
        let taken = array
            .take(PrimitiveArray::new(buffer![0, 5, 7], Validity::AllInvalid).into_array())
            .unwrap();
        assert_eq!(
            &array.dtype().with_nullability(Nullability::Nullable),
            taken.dtype()
        );
        assert_eq!(taken.len(), 3);
        assert_eq!(
            taken
                .validity()
                .unwrap()
                .to_mask(taken.len(), &mut LEGACY_SESSION.create_execution_ctx())
                .unwrap()
                .true_count(),
            0
        );
    }

    /// A null constant taken with mixed-validity indices must stay all-null and
    /// keep its (already nullable) dtype; no masking should be needed.
    #[test]
    fn take_null_constant_with_nullable_indices() {
        let array = ConstantArray::new(Scalar::null_native::<i64>(), 10).into_array();
        let taken = array
            .take(
                PrimitiveArray::new(buffer![0, 5, 7], Validity::from_iter([false, true, false]))
                    .into_array(),
            )
            .unwrap();
        assert_eq!(array.dtype(), taken.dtype());
        assert_eq!(taken.len(), 3);
        assert_eq!(
            taken
                .validity()
                .unwrap()
                .to_mask(taken.len(), &mut LEGACY_SESSION.create_execution_ctx())
                .unwrap()
                .true_count(),
            0
        );
    }

    #[rstest]
    #[case(ConstantArray::new(42i32, 5))]
    #[case(ConstantArray::new(std::f64::consts::PI, 10))]
    #[case(ConstantArray::new(Scalar::from("hello"), 3))]
    #[case(ConstantArray::new(Scalar::null_native::<i64>(), 5))]
    #[case(ConstantArray::new(true, 1))]
    fn test_take_constant_conformance(#[case] array: ConstantArray) {
        test_take_conformance(&array.into_array());
    }
}