1use vortex_array::Array;
5use vortex_array::ArrayRef;
6use vortex_array::ExecutionCtx;
7use vortex_array::IntoArray;
8use vortex_array::arrays::ConstantArray;
9use vortex_array::arrays::TakeExecute;
10use vortex_error::VortexResult;
11
12use crate::SparseArray;
13use crate::SparseVTable;
14
15impl TakeExecute for SparseVTable {
16 fn take(
17 array: &SparseArray,
18 indices: &dyn Array,
19 _ctx: &mut ExecutionCtx,
20 ) -> VortexResult<Option<ArrayRef>> {
21 let patches_take = if array.fill_scalar().is_null() {
22 array.patches().take(indices)?
23 } else {
24 array.patches().take_with_nulls(indices)?
25 };
26
27 let Some(new_patches) = patches_take else {
28 let result_fill_scalar = array.fill_scalar().cast(
29 &array
30 .dtype()
31 .union_nullability(indices.dtype().nullability()),
32 )?;
33 return Ok(Some(
34 ConstantArray::new(result_fill_scalar, indices.len()).into_array(),
35 ));
36 };
37
38 if new_patches.array_len() == new_patches.values().len() {
40 return Ok(Some(new_patches.into_values()));
41 }
42
43 Ok(Some(
44 SparseArray::try_new_from_patches(
45 new_patches,
46 array.fill_scalar().cast(
47 &array
48 .dtype()
49 .union_nullability(indices.dtype().nullability()),
50 )?,
51 )?
52 .into_array(),
53 ))
54 }
55}
56
#[cfg(test)]
mod test {
    use rstest::rstest;
    use vortex_array::Array;
    use vortex_array::ArrayRef;
    use vortex_array::IntoArray;
    use vortex_array::arrays::ConstantArray;
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::assert_arrays_eq;
    use vortex_array::scalar::Scalar;
    use vortex_array::validity::Validity;
    use vortex_buffer::buffer;
    use vortex_dtype::Nullability;

    use crate::SparseArray;

    /// Fill value of the `f64` fixture: a null `f64` scalar.
    fn test_array_fill_value() -> Scalar {
        Scalar::null_native::<f64>()
    }

    /// Length-100 `f64` sparse array with patches `1.23, 0.47, 9.99, 3.5` at
    /// positions `0, 37, 47, 99`; all other positions hold `test_array_fill_value()`.
    fn f64_sparse() -> SparseArray {
        let positions = buffer![0u64, 37, 47, 99].into_array();
        let patch_values =
            PrimitiveArray::new(buffer![1.23f64, 0.47, 9.99, 3.5], Validity::AllValid)
                .into_array();
        SparseArray::try_new(positions, patch_values, 100, test_array_fill_value()).unwrap()
    }

    /// The `f64` fixture as a type-erased `ArrayRef`.
    fn sparse_array() -> ArrayRef {
        f64_sparse().into_array()
    }

    /// Builds a length-10 `i32` sparse array with a non-nullable fill of `1` and the
    /// given patches, takes `[2, 1, null]`, and asserts the result is `[1, 10, null]`.
    ///
    /// Callers must patch position 1 with `10` and leave position 2 unpatched.
    fn assert_nullable_take(positions: ArrayRef, patch_values: ArrayRef) {
        let sparse = SparseArray::try_new(
            positions,
            patch_values,
            10,
            Scalar::primitive(1, Nullability::NonNullable),
        )
        .unwrap();

        let nullable_indices =
            PrimitiveArray::from_option_iter([Some(2u32), Some(1u32), Option::<u32>::None]);
        let taken = sparse.take(nullable_indices.to_array()).unwrap();

        let expected = PrimitiveArray::from_option_iter([Some(1), Some(10), Option::<i32>::None]);
        assert_arrays_eq!(taken, expected.to_array());
    }

    #[test]
    fn take_with_non_zero_offset() {
        let full = sparse_array();
        let window = full.slice(30..40).unwrap();
        // Window-relative 6, 7, 8 are absolute positions 36, 37, 38.
        let taken = window.take(buffer![6, 7, 8].into_array()).unwrap();
        let expected = PrimitiveArray::from_option_iter([Option::<f64>::None, Some(0.47), None]);
        assert_arrays_eq!(taken, expected.to_array());
    }

    #[test]
    fn sparse_take() {
        let full = sparse_array();
        // Every requested position is a patch, with repeats.
        let taken = full.take(buffer![0, 47, 47, 0, 99].into_array()).unwrap();
        let expected = PrimitiveArray::from_option_iter([
            Some(1.23f64),
            Some(9.99),
            Some(9.99),
            Some(1.23),
            Some(3.5),
        ]);
        assert_arrays_eq!(taken, expected.to_array());
    }

    #[test]
    fn nonexistent_take() {
        let full = sparse_array();
        // Position 69 is not patched, so only the fill value is observed.
        let taken = full.take(buffer![69].into_array()).unwrap();
        let expected = ConstantArray::new(test_array_fill_value(), 1).into_array();
        assert_arrays_eq!(taken, expected);
    }

    #[test]
    fn ordered_take() {
        let full = sparse_array();
        let taken = full.take(buffer![69, 37].into_array()).unwrap();
        let expected = PrimitiveArray::from_option_iter([Option::<f64>::None, Some(0.47f64)]);
        assert_arrays_eq!(taken, expected.to_array());
    }

    #[test]
    fn nullable_take() {
        assert_nullable_take(buffer![1u32].into_array(), buffer![10].into_array());
    }

    #[test]
    fn nullable_take_with_many_patches() {
        assert_nullable_take(
            buffer![1u32, 3, 7, 8, 9].into_array(),
            buffer![10, 8, 3, 2, 1].into_array(),
        );
    }

    #[rstest]
    #[case(f64_sparse())]
    #[case(SparseArray::try_new(
        buffer![1u32, 3, 7, 8, 9].into_array(),
        buffer![10, 8, 3, 2, 1].into_array(),
        10,
        Scalar::from(0i32),
    )
    .unwrap())]
    #[case({
        let nullable_values = PrimitiveArray::from_option_iter([Some(100i64), None, Some(300)]);
        SparseArray::try_new(
            buffer![2u64, 4, 6].into_array(),
            nullable_values.into_array(),
            10,
            Scalar::null_native::<i64>(),
        )
        .unwrap()
    })]
    #[case(SparseArray::try_new(
        buffer![5u64].into_array(),
        buffer![999i32].into_array(),
        20,
        Scalar::from(-1i32),
    )
    .unwrap())]
    fn test_take_sparse_conformance(#[case] sparse: SparseArray) {
        use vortex_array::compute::conformance::take::test_take_conformance;
        test_take_conformance(sparse.as_ref());
    }
}