vortex_array/arrays/constant/compute/
take.rs1use vortex_error::VortexResult;
5use vortex_mask::AllOr;
6
7use crate::ArrayRef;
8use crate::IntoArray;
9use crate::LEGACY_SESSION;
10use crate::VortexSessionExecute;
11use crate::array::ArrayView;
12use crate::arrays::Constant;
13use crate::arrays::ConstantArray;
14use crate::arrays::MaskedArray;
15use crate::arrays::dict::TakeReduce;
16use crate::arrays::dict::TakeReduceAdaptor;
17use crate::optimizer::rules::ParentRuleSet;
18use crate::scalar::Scalar;
19use crate::validity::Validity;
20
21impl TakeReduce for Constant {
22 fn take(array: ArrayView<'_, Constant>, indices: &ArrayRef) -> VortexResult<Option<ArrayRef>> {
23 let mut ctx = LEGACY_SESSION.create_execution_ctx();
24 let result = match indices
25 .validity()?
26 .execute_mask(indices.len(), &mut ctx)?
27 .bit_buffer()
28 {
29 AllOr::All => {
30 let scalar = Scalar::try_new(
31 array
32 .scalar()
33 .dtype()
34 .union_nullability(indices.dtype().nullability()),
35 array.scalar().value().cloned(),
36 )?;
37 ConstantArray::new(scalar, indices.len()).into_array()
38 }
39 AllOr::None => ConstantArray::new(
40 Scalar::null(
41 array
42 .dtype()
43 .union_nullability(indices.dtype().nullability()),
44 ),
45 indices.len(),
46 )
47 .into_array(),
48 AllOr::Some(v) => {
49 let arr = ConstantArray::new(array.scalar().clone(), indices.len()).into_array();
50
51 if array.scalar().is_null() {
52 return Ok(Some(arr));
53 }
54
55 MaskedArray::try_new(arr, Validity::from(v.clone()))?.into_array()
56 }
57 };
58 Ok(Some(result))
59 }
60}
61
62impl Constant {
63 pub const TAKE_RULES: ParentRuleSet<Self> =
64 ParentRuleSet::new(&[ParentRuleSet::lift(&TakeReduceAdaptor::<Self>(Self))]);
65}
66
#[cfg(test)]
mod tests {
    use std::f64;

    use rstest::rstest;
    use vortex_buffer::buffer;
    use vortex_mask::AllOr;

    use crate::IntoArray;
    use crate::LEGACY_SESSION;
    #[expect(deprecated)]
    use crate::ToCanonical as _;
    use crate::VortexSessionExecute;
    use crate::arrays::ConstantArray;
    use crate::arrays::PrimitiveArray;
    use crate::assert_arrays_eq;
    use crate::compute::conformance::take::test_take_conformance;
    use crate::dtype::Nullability;
    use crate::scalar::Scalar;
    use crate::validity::Validity;

    /// Taking with partially-null indices yields a nullable result whose validity
    /// mirrors the validity of the indices.
    #[test]
    fn take_nullable_indices() {
        let array = ConstantArray::new(42, 10).into_array();
        let indices = PrimitiveArray::new(
            buffer![0, 5, 7],
            Validity::from_iter(vec![false, true, false]),
        )
        .into_array();
        let taken = array.take(indices).unwrap();

        // Only the middle index is valid.
        let valid_indices: &[usize] = &[1usize];

        let expected_dtype = array.dtype().with_nullability(Nullability::Nullable);
        assert_eq!(&expected_dtype, taken.dtype());

        assert_arrays_eq!(
            #[expect(deprecated)]
            taken.to_primitive(),
            PrimitiveArray::new(
                buffer![42i32, 42, 42],
                Validity::from_iter([false, true, false])
            )
        );

        let mut ctx = LEGACY_SESSION.create_execution_ctx();
        let validity_mask = taken
            .validity()
            .unwrap()
            .execute_mask(taken.len(), &mut ctx)
            .unwrap();
        assert_eq!(validity_mask.indices(), AllOr::Some(valid_indices));
    }

    /// Taking with nullable-typed but fully valid indices yields a nullable result
    /// in which every element is valid.
    #[test]
    fn take_all_valid_indices() {
        let array = ConstantArray::new(42, 10).into_array();
        let indices = PrimitiveArray::new(buffer![0, 5, 7], Validity::AllValid).into_array();
        let taken = array.take(indices).unwrap();

        let expected_dtype = array.dtype().with_nullability(Nullability::Nullable);
        assert_eq!(&expected_dtype, taken.dtype());

        assert_arrays_eq!(
            #[expect(deprecated)]
            taken.to_primitive(),
            PrimitiveArray::new(buffer![42i32, 42, 42], Validity::AllValid)
        );

        let mut ctx = LEGACY_SESSION.create_execution_ctx();
        let validity_mask = taken
            .validity()
            .unwrap()
            .execute_mask(taken.len(), &mut ctx)
            .unwrap();
        assert_eq!(validity_mask.indices(), AllOr::All);
    }

    /// Runs the shared take conformance suite over constants of several dtypes,
    /// including a null constant.
    #[rstest]
    #[case(ConstantArray::new(42i32, 5))]
    #[case(ConstantArray::new(f64::consts::PI, 10))]
    #[case(ConstantArray::new(Scalar::from("hello"), 3))]
    #[case(ConstantArray::new(Scalar::null_native::<i64>(), 5))]
    #[case(ConstantArray::new(true, 1))]
    fn test_take_constant_conformance(#[case] array: ConstantArray) {
        let array = array.into_array();
        test_take_conformance(&array);
    }
}