1use vortex_array::Array;
5use vortex_array::ArrayRef;
6use vortex_array::IntoArray;
7use vortex_array::arrays::ConstantArray;
8use vortex_array::compute::TakeKernel;
9use vortex_array::compute::TakeKernelAdapter;
10use vortex_array::register_kernel;
11use vortex_error::VortexResult;
12
13use crate::SparseArray;
14use crate::SparseVTable;
15
16impl TakeKernel for SparseVTable {
17 fn take(&self, array: &SparseArray, take_indices: &dyn Array) -> VortexResult<ArrayRef> {
18 let patches_take = if array.fill_scalar().is_null() {
19 array.patches().take(take_indices)?
20 } else {
21 array.patches().take_with_nulls(take_indices)?
22 };
23
24 let Some(new_patches) = patches_take else {
25 let result_fill_scalar = array.fill_scalar().cast(
26 &array
27 .dtype()
28 .union_nullability(take_indices.dtype().nullability()),
29 )?;
30 return Ok(ConstantArray::new(result_fill_scalar, take_indices.len()).into_array());
31 };
32
33 if new_patches.array_len() == new_patches.values().len() {
35 return Ok(new_patches.into_values());
36 }
37
38 Ok(SparseArray::try_new_from_patches(
39 new_patches,
40 array.fill_scalar().cast(
41 &array
42 .dtype()
43 .union_nullability(take_indices.dtype().nullability()),
44 )?,
45 )?
46 .into_array())
47 }
48}
49
// Registers the `TakeKernel` impl above for `SparseVTable` through `TakeKernelAdapter`.
// NOTE(review): presumably this is what lets `compute::take` dispatch to it for sparse
// arrays — confirm against `register_kernel!`.
register_kernel!(TakeKernelAdapter(SparseVTable).lift());
51
#[cfg(test)]
mod test {
    use rstest::rstest;
    use vortex_array::Array;
    use vortex_array::ArrayRef;
    use vortex_array::IntoArray;
    use vortex_array::ToCanonical;
    use vortex_array::arrays::ConstantArray;
    use vortex_array::arrays::PrimitiveArray;
    use vortex_array::assert_arrays_eq;
    use vortex_array::compute::conformance::take::test_take_conformance;
    use vortex_array::compute::take;
    use vortex_array::validity::Validity;
    use vortex_buffer::buffer;
    use vortex_dtype::Nullability;
    use vortex_scalar::Scalar;

    use crate::SparseArray;
    use crate::SparseVTable;

    /// Fill value for the `f64` fixture: a typed `f64` null.
    fn test_array_fill_value() -> Scalar {
        Scalar::null_typed::<f64>()
    }

    /// Length-100 `f64` sparse array with all-valid patches at 0, 37, 47 and 99.
    fn sparse_f64() -> SparseArray {
        let indices = buffer![0u64, 37, 47, 99].into_array();
        let values =
            PrimitiveArray::new(buffer![1.23f64, 0.47, 9.99, 3.5], Validity::AllValid).into_array();
        SparseArray::try_new(indices, values, 100, test_array_fill_value()).unwrap()
    }

    /// [`sparse_f64`] erased to an [`ArrayRef`].
    fn sparse_array() -> ArrayRef {
        sparse_f64().into_array()
    }

    /// Length-10 `i64` sparse array with a null fill and one null among its patch values.
    fn sparse_i64_with_null_values() -> SparseArray {
        let values = PrimitiveArray::from_option_iter([Some(100i64), None, Some(300)]);
        SparseArray::try_new(
            buffer![2u64, 4, 6].into_array(),
            values.into_array(),
            10,
            Scalar::null_typed::<i64>(),
        )
        .unwrap()
    }

    /// Takes the nullable indices `[2, 1, null]` from `arr` (non-null fill `1`, patch `10`
    /// at position 1) and checks the result is `[1, 10, null]`.
    fn check_nullable_take(arr: &SparseArray) {
        let indices =
            PrimitiveArray::from_option_iter([Some(2u32), Some(1u32), Option::<u32>::None]);
        let taken = take(arr.as_ref(), indices.as_ref()).unwrap();
        let expected = PrimitiveArray::from_option_iter([Some(1), Some(10), Option::<i32>::None]);
        assert_arrays_eq!(taken, expected.to_array());
    }

    #[test]
    fn take_with_non_zero_offset() {
        // Within the 30..40 slice, positions 6, 7, 8 are original positions 36, 37, 38.
        let sliced = sparse_array().slice(30..40);
        let taken = take(&sliced, &buffer![6, 7, 8].into_array()).unwrap();
        let expected = PrimitiveArray::from_option_iter([None, Some(0.47f64), None]);
        assert_arrays_eq!(taken, expected.to_array());
    }

    #[test]
    fn sparse_take() {
        let indices = buffer![0, 47, 47, 0, 99].into_array();
        let taken = take(&sparse_array(), &indices).unwrap();
        let expected = PrimitiveArray::from_option_iter([1.23f64, 9.99, 9.99, 1.23, 3.5].map(Some));
        assert_arrays_eq!(taken, expected.to_array());
    }

    #[test]
    fn nonexistent_take() {
        // Position 69 is not patched, so only the fill value can come back.
        let taken = take(&sparse_array(), &buffer![69].into_array()).unwrap();
        let expected = ConstantArray::new(test_array_fill_value(), 1).into_array();
        assert_arrays_eq!(taken, expected);
    }

    #[test]
    fn ordered_take() {
        let taken_arr = take(&sparse_array(), &buffer![69, 37].into_array()).unwrap();
        let taken = taken_arr.as_::<SparseVTable>();
        let patches = taken.patches();

        // Only the second requested position (37) hits a patch.
        assert_eq!(patches.indices().to_primitive().as_slice::<u64>(), [1]);
        assert_eq!(patches.values().to_primitive().as_slice::<f64>(), [0.47f64]);
        assert_eq!(taken.len(), 2);
    }

    #[test]
    fn nullable_take() {
        let arr = SparseArray::try_new(
            buffer![1u32].into_array(),
            buffer![10].into_array(),
            10,
            Scalar::primitive(1, Nullability::NonNullable),
        )
        .unwrap();
        check_nullable_take(&arr);
    }

    #[test]
    fn nullable_take_with_many_patches() {
        let arr = SparseArray::try_new(
            buffer![1u32, 3, 7, 8, 9].into_array(),
            buffer![10, 8, 3, 2, 1].into_array(),
            10,
            Scalar::primitive(1, Nullability::NonNullable),
        )
        .unwrap();
        check_nullable_take(&arr);
    }

    #[rstest]
    #[case(sparse_f64())]
    #[case(SparseArray::try_new(
        buffer![1u32, 3, 7, 8, 9].into_array(),
        buffer![10, 8, 3, 2, 1].into_array(),
        10,
        Scalar::from(0i32),
    ).unwrap())]
    #[case(sparse_i64_with_null_values())]
    #[case(SparseArray::try_new(
        buffer![5u64].into_array(),
        buffer![999i32].into_array(),
        20,
        Scalar::from(-1i32),
    ).unwrap())]
    fn test_take_sparse_conformance(#[case] sparse: SparseArray) {
        test_take_conformance(sparse.as_ref());
    }
}