vortex_array/arrays/constant/compute/
take.rs1use vortex_error::VortexResult;
5use vortex_mask::AllOr;
6
7use crate::ArrayRef;
8use crate::IntoArray;
9use crate::LEGACY_SESSION;
10use crate::VortexSessionExecute;
11use crate::array::ArrayView;
12use crate::arrays::Constant;
13use crate::arrays::ConstantArray;
14use crate::arrays::MaskedArray;
15use crate::arrays::dict::TakeReduce;
16use crate::arrays::dict::TakeReduceAdaptor;
17use crate::optimizer::rules::ParentRuleSet;
18use crate::scalar::Scalar;
19use crate::validity::Validity;
20
21impl TakeReduce for Constant {
22 fn take(array: ArrayView<'_, Constant>, indices: &ArrayRef) -> VortexResult<Option<ArrayRef>> {
23 let mut ctx = LEGACY_SESSION.create_execution_ctx();
24 let result = match indices
25 .validity()?
26 .execute_mask(indices.len(), &mut ctx)?
27 .bit_buffer()
28 {
29 AllOr::All => {
30 let scalar = Scalar::try_new(
31 array
32 .scalar()
33 .dtype()
34 .union_nullability(indices.dtype().nullability()),
35 array.scalar().value().cloned(),
36 )?;
37 ConstantArray::new(scalar, indices.len()).into_array()
38 }
39 AllOr::None => ConstantArray::new(
40 Scalar::null(
41 array
42 .dtype()
43 .union_nullability(indices.dtype().nullability()),
44 ),
45 indices.len(),
46 )
47 .into_array(),
48 AllOr::Some(v) => {
49 let arr = ConstantArray::new(array.scalar().clone(), indices.len()).into_array();
50
51 if array.scalar().is_null() {
52 return Ok(Some(arr));
53 }
54
55 MaskedArray::try_new(arr, Validity::from(v.clone()))?.into_array()
56 }
57 };
58 Ok(Some(result))
59 }
60}
61
62impl Constant {
63 pub const TAKE_RULES: ParentRuleSet<Self> =
64 ParentRuleSet::new(&[ParentRuleSet::lift(&TakeReduceAdaptor::<Self>(Self))]);
65}
66
#[cfg(test)]
mod tests {
    use std::f64;

    use rstest::rstest;
    use vortex_buffer::buffer;
    use vortex_mask::AllOr;

    use crate::IntoArray;
    #[expect(deprecated)]
    use crate::ToCanonical as _;
    use crate::VortexSessionExecute;
    use crate::array_session;
    use crate::arrays::ConstantArray;
    use crate::arrays::PrimitiveArray;
    use crate::assert_arrays_eq;
    use crate::compute::conformance::take::test_take_conformance;
    use crate::dtype::Nullability;
    use crate::scalar::Scalar;
    use crate::validity::Validity;

    /// Taking with partially-null indices yields a nullable array that repeats the constant and
    /// is valid exactly where the indices were valid.
    #[test]
    fn take_nullable_indices() {
        let mut ctx = array_session().create_execution_ctx();
        let array = ConstantArray::new(42, 10).into_array();

        // Only the middle index is valid.
        let indices = PrimitiveArray::new(
            buffer![0, 5, 7],
            Validity::from_iter(vec![false, true, false]),
        )
        .into_array();
        let taken = array.take(indices).unwrap();

        let expected_dtype = array.dtype().with_nullability(Nullability::Nullable);
        assert_eq!(&expected_dtype, taken.dtype());

        assert_arrays_eq!(
            #[expect(deprecated)]
            taken.to_primitive(),
            PrimitiveArray::new(
                buffer![42i32, 42, 42],
                Validity::from_iter([false, true, false])
            ),
            &mut ctx
        );

        // The validity mask of the result must select position 1 only.
        let mut mask_ctx = array_session().create_execution_ctx();
        let validity_mask = taken
            .validity()
            .unwrap()
            .execute_mask(taken.len(), &mut mask_ctx)
            .unwrap();
        let expected_valid: &[usize] = &[1usize];
        assert_eq!(validity_mask.indices(), AllOr::Some(expected_valid));
    }

    /// Taking with all-valid indices yields an all-valid array whose dtype is nullable.
    #[test]
    fn take_all_valid_indices() {
        let mut ctx = array_session().create_execution_ctx();
        let array = ConstantArray::new(42, 10).into_array();

        let indices = PrimitiveArray::new(buffer![0, 5, 7], Validity::AllValid).into_array();
        let taken = array.take(indices).unwrap();

        let expected_dtype = array.dtype().with_nullability(Nullability::Nullable);
        assert_eq!(&expected_dtype, taken.dtype());

        assert_arrays_eq!(
            #[expect(deprecated)]
            taken.to_primitive(),
            PrimitiveArray::new(buffer![42i32, 42, 42], Validity::AllValid),
            &mut ctx
        );

        // Every position of the result must be valid.
        let mut mask_ctx = array_session().create_execution_ctx();
        let validity_mask = taken
            .validity()
            .unwrap()
            .execute_mask(taken.len(), &mut mask_ctx)
            .unwrap();
        assert_eq!(validity_mask.indices(), AllOr::All);
    }

    /// Runs the shared take conformance suite over constant arrays of several scalar kinds:
    /// integer, float, string, null and boolean.
    #[rstest]
    #[case(ConstantArray::new(42i32, 5))]
    #[case(ConstantArray::new(f64::consts::PI, 10))]
    #[case(ConstantArray::new(Scalar::from("hello"), 3))]
    #[case(ConstantArray::new(Scalar::null_native::<i64>(), 5))]
    #[case(ConstantArray::new(true, 1))]
    fn test_take_constant_conformance(#[case] array: ConstantArray) {
        let array_ref = array.into_array();
        test_take_conformance(&array_ref);
    }
}