1use vortex_array::arrays::ConstantArray;
5use vortex_array::compute::{TakeKernel, TakeKernelAdapter};
6use vortex_array::{Array, ArrayRef, IntoArray, register_kernel};
7use vortex_error::VortexResult;
8
9use crate::{SparseArray, SparseVTable};
10
11impl TakeKernel for SparseVTable {
12 fn take(&self, array: &SparseArray, take_indices: &dyn Array) -> VortexResult<ArrayRef> {
13 let patches_take = if array.fill_scalar().is_null() {
14 array.patches().take(take_indices)?
15 } else {
16 array.patches().take_with_nulls(take_indices)?
17 };
18
19 let Some(new_patches) = patches_take else {
20 let result_fill_scalar = array.fill_scalar().cast(
21 &array
22 .dtype()
23 .union_nullability(take_indices.dtype().nullability()),
24 )?;
25 return Ok(ConstantArray::new(result_fill_scalar, take_indices.len()).into_array());
26 };
27
28 if new_patches.array_len() == new_patches.values().len() {
30 return Ok(new_patches.into_values());
31 }
32
33 Ok(SparseArray::try_new_from_patches(
34 new_patches,
35 array.fill_scalar().cast(
36 &array
37 .dtype()
38 .union_nullability(take_indices.dtype().nullability()),
39 )?,
40 )?
41 .into_array())
42 }
43}
44
45register_kernel!(TakeKernelAdapter(SparseVTable).lift());
46
#[cfg(test)]
mod test {
    use rstest::rstest;
    use vortex_array::arrays::{ConstantArray, PrimitiveArray};
    use vortex_array::compute::take;
    use vortex_array::validity::Validity;
    use vortex_array::{Array, ArrayRef, IntoArray, ToCanonical, assert_arrays_eq};
    use vortex_buffer::buffer;
    use vortex_dtype::Nullability;
    use vortex_scalar::Scalar;

    use crate::{SparseArray, SparseVTable};

    /// Fill value of the shared fixture: a null `f64`.
    fn test_array_fill_value() -> Scalar {
        Scalar::null_typed::<f64>()
    }

    /// Shared fixture: a length-100 `f64` sparse array patched at positions 0, 37, 47 and 99
    /// (values 1.23, 0.47, 9.99 and 3.5) with a null fill value.
    fn sparse_fixture() -> SparseArray {
        let patch_indices = buffer![0u64, 37, 47, 99].into_array();
        let patch_values =
            PrimitiveArray::new(buffer![1.23f64, 0.47, 9.99, 3.5], Validity::AllValid).into_array();
        SparseArray::try_new(patch_indices, patch_values, 100, test_array_fill_value()).unwrap()
    }

    /// The shared fixture as a type-erased [`ArrayRef`].
    fn sparse_array() -> ArrayRef {
        sparse_fixture().into_array()
    }

    /// Builds a length-10 `i32` sparse array from the given patches with a non-nullable fill
    /// of `1`, takes positions `[2, 1, null]`, and asserts the result is `[1, 10, null]`.
    fn assert_nullable_take(patch_indices: ArrayRef, patch_values: ArrayRef) {
        let arr = SparseArray::try_new(
            patch_indices,
            patch_values,
            10,
            Scalar::primitive(1, Nullability::NonNullable),
        )
        .unwrap();
        let indices =
            PrimitiveArray::from_option_iter([Some(2u32), Some(1u32), Option::<u32>::None]);

        let taken = take(arr.as_ref(), indices.as_ref()).unwrap();

        let expected = PrimitiveArray::from_option_iter([Some(1), Some(10), Option::<i32>::None]);
        assert_arrays_eq!(taken, expected.to_array());
    }

    #[test]
    fn take_with_non_zero_offset() {
        // The slice [30, 40) keeps only the patch at 37, which lands at offset 7.
        let full = sparse_array();
        let sliced = full.slice(30..40);

        let taken = take(&sliced, &buffer![6, 7, 8].into_array()).unwrap();

        let expected = PrimitiveArray::from_option_iter([None, Some(0.47f64), None]);
        assert_arrays_eq!(taken, expected.to_array());
    }

    #[test]
    fn sparse_take() {
        let sparse = sparse_array();
        let indices = buffer![0, 47, 47, 0, 99].into_array();

        let taken = take(&sparse, &indices).unwrap();

        // Every requested position is a patch, repeats included.
        let expected =
            PrimitiveArray::from_option_iter([1.23f64, 9.99, 9.99, 1.23, 3.5].map(Some));
        assert_arrays_eq!(taken, expected.to_array());
    }

    #[test]
    fn nonexistent_take() {
        // Position 69 is not patched, so the single taken element is the fill value.
        let taken = take(&sparse_array(), &buffer![69].into_array()).unwrap();

        let expected = ConstantArray::new(test_array_fill_value(), 1).into_array();
        assert_arrays_eq!(taken, expected);
    }

    #[test]
    fn ordered_take() {
        let sparse = sparse_array();
        let taken_arr = take(&sparse, &buffer![69, 37].into_array()).unwrap();

        // Only the second requested position (37) is patched, so the result stays sparse
        // with a single patch at output index 1.
        let taken = taken_arr.as_::<SparseVTable>();
        assert_eq!(taken.len(), 2);

        let patches = taken.patches();
        assert_eq!(patches.indices().to_primitive().as_slice::<u64>(), [1]);
        assert_eq!(patches.values().to_primitive().as_slice::<f64>(), [0.47f64]);
    }

    #[test]
    fn nullable_take() {
        assert_nullable_take(buffer![1u32].into_array(), buffer![10].into_array());
    }

    #[test]
    fn nullable_take_with_many_patches() {
        assert_nullable_take(
            buffer![1u32, 3, 7, 8, 9].into_array(),
            buffer![10, 8, 3, 2, 1].into_array(),
        );
    }

    /// Runs the shared take conformance suite over a variety of sparse layouts: null fill,
    /// non-null fill, nullable patch values, and a single patch.
    #[rstest]
    #[case(sparse_fixture())]
    #[case(SparseArray::try_new(
        buffer![1u32, 3, 7, 8, 9].into_array(),
        buffer![10, 8, 3, 2, 1].into_array(),
        10,
        Scalar::from(0i32),
    ).unwrap())]
    #[case({
        let nullable_values = PrimitiveArray::from_option_iter([Some(100i64), None, Some(300)]);
        SparseArray::try_new(
            buffer![2u64, 4, 6].into_array(),
            nullable_values.into_array(),
            10,
            Scalar::null_typed::<i64>(),
        ).unwrap()
    })]
    #[case(SparseArray::try_new(
        buffer![5u64].into_array(),
        buffer![999i32].into_array(),
        20,
        Scalar::from(-1i32),
    ).unwrap())]
    fn test_take_sparse_conformance(#[case] array: SparseArray) {
        vortex_array::compute::conformance::take::test_take_conformance(array.as_ref());
    }
}