vortex_array/arrays/constant/compute/
take.rs1use vortex_error::VortexResult;
5use vortex_mask::AllOr;
6
7use crate::ArrayRef;
8use crate::IntoArray;
9use crate::LEGACY_SESSION;
10use crate::VortexSessionExecute;
11use crate::array::ArrayView;
12use crate::arrays::Constant;
13use crate::arrays::ConstantArray;
14use crate::arrays::MaskedArray;
15use crate::arrays::dict::TakeReduce;
16use crate::arrays::dict::TakeReduceAdaptor;
17use crate::optimizer::rules::ParentRuleSet;
18use crate::scalar::Scalar;
19use crate::validity::Validity;
20
21impl TakeReduce for Constant {
22 fn take(array: ArrayView<'_, Constant>, indices: &ArrayRef) -> VortexResult<Option<ArrayRef>> {
23 let mut ctx = LEGACY_SESSION.create_execution_ctx();
24 let result = match indices
25 .validity()?
26 .execute_mask(indices.len(), &mut ctx)?
27 .bit_buffer()
28 {
29 AllOr::All => {
30 let scalar = Scalar::try_new(
31 array
32 .scalar()
33 .dtype()
34 .union_nullability(indices.dtype().nullability()),
35 array.scalar().value().cloned(),
36 )?;
37 ConstantArray::new(scalar, indices.len()).into_array()
38 }
39 AllOr::None => ConstantArray::new(
40 Scalar::null(
41 array
42 .dtype()
43 .union_nullability(indices.dtype().nullability()),
44 ),
45 indices.len(),
46 )
47 .into_array(),
48 AllOr::Some(v) => {
49 let arr = ConstantArray::new(array.scalar().clone(), indices.len()).into_array();
50
51 if array.scalar().is_null() {
52 return Ok(Some(arr));
53 }
54
55 MaskedArray::try_new(arr, Validity::from(v.clone()))?.into_array()
56 }
57 };
58 Ok(Some(result))
59 }
60}
61
62impl Constant {
63 pub const TAKE_RULES: ParentRuleSet<Self> =
64 ParentRuleSet::new(&[ParentRuleSet::lift(&TakeReduceAdaptor::<Self>(Self))]);
65}
66
67#[cfg(test)]
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use vortex_buffer::buffer;
    use vortex_mask::AllOr;

    use crate::IntoArray;
    use crate::LEGACY_SESSION;
    // `ToCanonical` is deprecated; it is imported anonymously for its trait methods
    // (presumably `to_primitive`, used in the assertions below).
    #[expect(deprecated)]
    use crate::ToCanonical as _;
    use crate::VortexSessionExecute;
    use crate::arrays::ConstantArray;
    use crate::arrays::PrimitiveArray;
    use crate::assert_arrays_eq;
    use crate::compute::conformance::take::test_take_conformance;
    use crate::dtype::Nullability;
    use crate::scalar::Scalar;
    use crate::validity::Validity;

    /// Taking with partially-null indices (`[null, 5, null]`) from a constant `42` must
    /// yield a nullable array `[null, 42, null]` whose only valid position is index 1.
    #[test]
    fn take_nullable_indices() {
        let array = ConstantArray::new(42, 10).into_array();
        let taken = array
            .take(
                PrimitiveArray::new(
                    buffer![0, 5, 7],
                    Validity::from_iter(vec![false, true, false]),
                )
                .into_array(),
            )
            .unwrap();
        let valid_indices: &[usize] = &[1usize];
        // Nullable indices force a nullable result dtype.
        assert_eq!(
            &array.dtype().with_nullability(Nullability::Nullable),
            taken.dtype()
        );
        assert_arrays_eq!(
            #[expect(deprecated)]
            taken.to_primitive(),
            PrimitiveArray::new(
                buffer![42i32, 42, 42],
                Validity::from_iter([false, true, false])
            )
        );
        // Check the validity mask directly as well: only position 1 is valid.
        assert_eq!(
            taken
                .validity()
                .unwrap()
                .execute_mask(taken.len(), &mut LEGACY_SESSION.create_execution_ctx())
                .unwrap()
                .indices(),
            AllOr::Some(valid_indices)
        );
    }

    /// Taking with `AllValid` indices still yields a nullable dtype (the take impl unions in
    /// the indices' nullability), but every position is valid.
    #[test]
    fn take_all_valid_indices() {
        let array = ConstantArray::new(42, 10).into_array();
        let taken = array
            .take(PrimitiveArray::new(buffer![0, 5, 7], Validity::AllValid).into_array())
            .unwrap();
        assert_eq!(
            &array.dtype().with_nullability(Nullability::Nullable),
            taken.dtype()
        );
        assert_arrays_eq!(
            #[expect(deprecated)]
            taken.to_primitive(),
            PrimitiveArray::new(buffer![42i32, 42, 42], Validity::AllValid)
        );
        // The mask must report every position as valid.
        assert_eq!(
            taken
                .validity()
                .unwrap()
                .execute_mask(taken.len(), &mut LEGACY_SESSION.create_execution_ctx())
                .unwrap()
                .indices(),
            AllOr::All
        );
    }

    /// Runs the shared take conformance suite over constants of several dtypes and lengths,
    /// including a null `i64` constant and a length-1 array.
    #[rstest]
    #[case(ConstantArray::new(42i32, 5))]
    #[case(ConstantArray::new(std::f64::consts::PI, 10))]
    #[case(ConstantArray::new(Scalar::from("hello"), 3))]
    #[case(ConstantArray::new(Scalar::null_native::<i64>(), 5))]
    #[case(ConstantArray::new(true, 1))]
    fn test_take_constant_conformance(#[case] array: ConstantArray) {
        test_take_conformance(&array.into_array());
    }
}