1use vortex_array::ArrayRef;
5use vortex_array::ArrayView;
6use vortex_array::ExecutionCtx;
7use vortex_array::IntoArray;
8use vortex_array::arrays::dict::TakeExecute;
9use vortex_error::VortexResult;
10
11use crate::ConstantArray;
12use crate::Sparse;
13impl TakeExecute for Sparse {
14 fn take(
15 array: ArrayView<'_, Self>,
16 indices: &ArrayRef,
17 ctx: &mut ExecutionCtx,
18 ) -> VortexResult<Option<ArrayRef>> {
19 let patches_take = if array.fill_scalar().is_null() {
20 array.patches().take(indices, ctx)?
21 } else {
22 array.patches().take_with_nulls(indices, ctx)?
23 };
24
25 let Some(new_patches) = patches_take else {
26 let result_fill_scalar = array.fill_scalar().cast(
27 &array
28 .dtype()
29 .union_nullability(indices.dtype().nullability()),
30 )?;
31 return Ok(Some(
32 ConstantArray::new(result_fill_scalar, indices.len()).into_array(),
33 ));
34 };
35
36 if new_patches.array_len() == new_patches.values().len() {
38 return Ok(Some(new_patches.into_values()));
39 }
40
41 Ok(Some(
42 Sparse::try_new_from_patches(
43 new_patches,
44 array.fill_scalar().cast(
45 &array
46 .dtype()
47 .union_nullability(indices.dtype().nullability()),
48 )?,
49 )?
50 .into_array(),
51 ))
52 }
53}
54
#[cfg(test)]
mod test {
    use rstest::rstest;
    use vortex_array::ArrayRef;
    use vortex_array::IntoArray;
    use vortex_array::arrays::ConstantArray;
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::assert_arrays_eq;
    use vortex_array::dtype::Nullability;
    use vortex_array::scalar::Scalar;
    use vortex_array::validity::Validity;
    use vortex_buffer::buffer;

    use crate::Sparse;
    use crate::SparseArray;

    /// Fill value used by the `f64` fixture: a null `f64` scalar.
    fn test_array_fill_value() -> Scalar {
        Scalar::null_native::<f64>()
    }

    /// Builds a 100-element sparse `f64` array with patches at positions 0, 37, 47 and 99
    /// and a null fill value.
    fn sparse_array() -> ArrayRef {
        let patch_indices = buffer![0u64, 37, 47, 99].into_array();
        let patch_values =
            PrimitiveArray::new(buffer![1.23f64, 0.47, 9.99, 3.5], Validity::AllValid).into_array();

        Sparse::try_new(patch_indices, patch_values, 100, test_array_fill_value())
            .unwrap()
            .into_array()
    }

    /// Taking from a slice that starts at 30: positions 6, 7, 8 of the slice are positions
    /// 36, 37, 38 of the fixture, of which only 37 is patched.
    #[test]
    fn take_with_non_zero_offset() {
        let window = sparse_array().slice(30..40).unwrap();
        let result = window.take(buffer![6, 7, 8].into_array()).unwrap();

        let want = PrimitiveArray::from_option_iter([None, Some(0.47f64), None]);
        assert_arrays_eq!(result, want.into_array());
    }

    /// Every requested position is patched, including repeated and out-of-order positions.
    #[test]
    fn sparse_take() {
        let result = sparse_array()
            .take(buffer![0, 47, 47, 0, 99].into_array())
            .unwrap();

        let want = PrimitiveArray::from_option_iter([
            Some(1.23f64),
            Some(9.99),
            Some(9.99),
            Some(1.23),
            Some(3.5),
        ]);
        assert_arrays_eq!(result, want.into_array());
    }

    /// A single unpatched position yields the fill value.
    #[test]
    fn nonexistent_take() {
        let result = sparse_array().take(buffer![69].into_array()).unwrap();

        let want = ConstantArray::new(test_array_fill_value(), 1).into_array();
        assert_arrays_eq!(result, want);
    }

    /// An unpatched position followed by a patched one.
    #[test]
    fn ordered_take() {
        let result = sparse_array().take(buffer![69, 37].into_array()).unwrap();

        let want = PrimitiveArray::from_option_iter([None, Some(0.47f64)]);
        assert_arrays_eq!(result, want.into_array());
    }

    /// Nullable indices against a non-nullable fill of 1 and a single patch at position 1:
    /// fill, patch value, then null for the null index.
    #[test]
    fn nullable_take() {
        let sparse = Sparse::try_new(
            buffer![1u32].into_array(),
            buffer![10].into_array(),
            10,
            Scalar::primitive(1, Nullability::NonNullable),
        )
        .unwrap();

        let nullable_indices =
            PrimitiveArray::from_option_iter([Some(2u32), Some(1u32), None]).into_array();
        let result = sparse.take(nullable_indices).unwrap();

        let want = PrimitiveArray::from_option_iter([Some(1i32), Some(10), None]);
        assert_arrays_eq!(result, want.into_array());
    }

    /// Same indices and expectation as `nullable_take`, but with five patches in the array.
    #[test]
    fn nullable_take_with_many_patches() {
        let sparse = Sparse::try_new(
            buffer![1u32, 3, 7, 8, 9].into_array(),
            buffer![10, 8, 3, 2, 1].into_array(),
            10,
            Scalar::primitive(1, Nullability::NonNullable),
        )
        .unwrap();

        let nullable_indices =
            PrimitiveArray::from_option_iter([Some(2u32), Some(1u32), None]).into_array();
        let result = sparse.take(nullable_indices).unwrap();

        let want = PrimitiveArray::from_option_iter([Some(1i32), Some(10), None]);
        assert_arrays_eq!(result, want.into_array());
    }

    /// Runs the shared take conformance suite over several sparse layouts: a null-filled
    /// `f64` array, an `i32` array filled with 0, a null-filled `i64` array whose patch
    /// values include a null, and a 20-element `i32` array with a single patch.
    #[rstest]
    #[case(
        Sparse::try_new(
            buffer![0u64, 37, 47, 99].into_array(),
            PrimitiveArray::new(buffer![1.23f64, 0.47, 9.99, 3.5], Validity::AllValid).into_array(),
            100,
            Scalar::null_native::<f64>(),
        )
        .unwrap()
    )]
    #[case(
        Sparse::try_new(
            buffer![1u32, 3, 7, 8, 9].into_array(),
            buffer![10, 8, 3, 2, 1].into_array(),
            10,
            Scalar::from(0i32),
        )
        .unwrap()
    )]
    #[case(
        Sparse::try_new(
            buffer![2u64, 4, 6].into_array(),
            PrimitiveArray::from_option_iter([Some(100i64), None, Some(300)]).into_array(),
            10,
            Scalar::null_native::<i64>(),
        )
        .unwrap()
    )]
    #[case(
        Sparse::try_new(
            buffer![5u64].into_array(),
            buffer![999i32].into_array(),
            20,
            Scalar::from(-1i32),
        )
        .unwrap()
    )]
    fn test_take_sparse_conformance(#[case] sparse: SparseArray) {
        let as_array = sparse.into_array();
        vortex_array::compute::conformance::take::test_take_conformance(&as_array);
    }
}