1use vortex_array::ArrayRef;
5use vortex_array::ArrayView;
6use vortex_array::ExecutionCtx;
7use vortex_array::IntoArray;
8use vortex_array::arrays::dict::TakeExecute;
9use vortex_error::VortexResult;
10
11use crate::ConstantArray;
12use crate::Sparse;
13use crate::SparseExt as _;
14impl TakeExecute for Sparse {
15 fn take(
16 array: ArrayView<'_, Self>,
17 indices: &ArrayRef,
18 ctx: &mut ExecutionCtx,
19 ) -> VortexResult<Option<ArrayRef>> {
20 let patches_take = if array.fill_scalar().is_null() {
21 array.patches().take(indices, ctx)?
22 } else {
23 array.patches().take_with_nulls(indices, ctx)?
24 };
25
26 let Some(new_patches) = patches_take else {
27 let result_fill_scalar = array.fill_scalar().cast(
28 &array
29 .dtype()
30 .union_nullability(indices.dtype().nullability()),
31 )?;
32 return Ok(Some(
33 ConstantArray::new(result_fill_scalar, indices.len()).into_array(),
34 ));
35 };
36
37 if new_patches.array_len() == new_patches.values().len() {
39 return Ok(Some(new_patches.into_values()));
40 }
41
42 Ok(Some(
43 Sparse::try_new_from_patches(
44 new_patches,
45 array.fill_scalar().cast(
46 &array
47 .dtype()
48 .union_nullability(indices.dtype().nullability()),
49 )?,
50 )?
51 .into_array(),
52 ))
53 }
54}
55
#[cfg(test)]
mod test {
    use rstest::rstest;
    use vortex_array::ArrayRef;
    use vortex_array::IntoArray;
    use vortex_array::arrays::ConstantArray;
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::assert_arrays_eq;
    use vortex_array::dtype::Nullability;
    use vortex_array::scalar::Scalar;
    use vortex_array::validity::Validity;
    use vortex_buffer::buffer;

    use crate::Sparse;
    use crate::SparseArray;

    /// Null `f64` fill value used by the [`sparse_array`] fixture.
    fn test_array_fill_value() -> Scalar {
        Scalar::null_native::<f64>()
    }

    /// Length-100 `f64` sparse array with patches at rows 0, 37, 47 and 99 and a null fill.
    fn sparse_array() -> ArrayRef {
        let patch_indices = buffer![0u64, 37, 47, 99].into_array();
        let patch_values =
            PrimitiveArray::new(buffer![1.23f64, 0.47, 9.99, 3.5], Validity::AllValid).into_array();
        Sparse::try_new(patch_indices, patch_values, 100, test_array_fill_value())
            .unwrap()
            .into_array()
    }

    /// Nullable `u32` indices `[2, 1, null]` shared by the nullable-take tests.
    fn nullable_indices() -> ArrayRef {
        PrimitiveArray::from_option_iter([Some(2u32), Some(1u32), Option::<u32>::None])
            .into_array()
    }

    /// Expected output `[1, 10, null]` for [`nullable_indices`] against a fill of 1 with
    /// row 1 patched to 10.
    fn nullable_expected() -> ArrayRef {
        PrimitiveArray::from_option_iter([Some(1), Some(10), Option::<i32>::None]).into_array()
    }

    #[test]
    fn take_with_non_zero_offset() {
        let full = sparse_array();
        let sliced = full.slice(30..40).unwrap();
        // Slice offsets 6..=8 correspond to rows 36..=38; only row 37 carries a patch.
        let result = sliced.take(buffer![6, 7, 8].into_array()).unwrap();
        let want = PrimitiveArray::from_option_iter([Option::<f64>::None, Some(0.47), None]);
        assert_arrays_eq!(result, want.into_array());
    }

    #[test]
    fn sparse_take() {
        let source = sparse_array();
        let positions = buffer![0, 47, 47, 0, 99].into_array();
        let result = source.take(positions).unwrap();
        let want = PrimitiveArray::from_option_iter([1.23f64, 9.99, 9.99, 1.23, 3.5].map(Some));
        assert_arrays_eq!(result, want.into_array());
    }

    #[test]
    fn nonexistent_take() {
        let source = sparse_array();
        // Row 69 is not patched, so the whole result collapses to the fill value.
        let result = source.take(buffer![69].into_array()).unwrap();
        let want = ConstantArray::new(test_array_fill_value(), 1).into_array();
        assert_arrays_eq!(result, want);
    }

    #[test]
    fn ordered_take() {
        let source = sparse_array();
        let result = source.take(buffer![69, 37].into_array()).unwrap();
        let want = PrimitiveArray::from_option_iter([Option::<f64>::None, Some(0.47f64)]);
        assert_arrays_eq!(result, want.into_array());
    }

    #[test]
    fn nullable_take() {
        let single_patch = Sparse::try_new(
            buffer![1u32].into_array(),
            buffer![10].into_array(),
            10,
            Scalar::primitive(1, Nullability::NonNullable),
        )
        .unwrap();

        let result = single_patch.take(nullable_indices()).unwrap();
        assert_arrays_eq!(result, nullable_expected());
    }

    #[test]
    fn nullable_take_with_many_patches() {
        let many_patches = Sparse::try_new(
            buffer![1u32, 3, 7, 8, 9].into_array(),
            buffer![10, 8, 3, 2, 1].into_array(),
            10,
            Scalar::primitive(1, Nullability::NonNullable),
        )
        .unwrap();

        let result = many_patches.take(nullable_indices()).unwrap();
        assert_arrays_eq!(result, nullable_expected());
    }

    #[rstest]
    #[case(Sparse::try_new(
        buffer![0u64, 37, 47, 99].into_array(),
        PrimitiveArray::new(buffer![1.23f64, 0.47, 9.99, 3.5], Validity::AllValid).into_array(),
        100,
        Scalar::null_native::<f64>(),
    ).unwrap())]
    #[case(Sparse::try_new(
        buffer![1u32, 3, 7, 8, 9].into_array(),
        buffer![10, 8, 3, 2, 1].into_array(),
        10,
        Scalar::from(0i32),
    ).unwrap())]
    #[case({
        let nullable_values = PrimitiveArray::from_option_iter([Some(100i64), None, Some(300)]);
        Sparse::try_new(
            buffer![2u64, 4, 6].into_array(),
            nullable_values.into_array(),
            10,
            Scalar::null_native::<i64>(),
        ).unwrap()
    })]
    #[case(Sparse::try_new(
        buffer![5u64].into_array(),
        buffer![999i32].into_array(),
        20,
        Scalar::from(-1i32),
    ).unwrap())]
    fn test_take_sparse_conformance(#[case] sparse: SparseArray) {
        use vortex_array::compute::conformance::take::test_take_conformance;
        let erased = sparse.into_array();
        test_take_conformance(&erased);
    }
}