1use vortex_array::arrays::ConstantArray;
5use vortex_array::compute::{TakeKernel, TakeKernelAdapter};
6use vortex_array::{Array, ArrayRef, IntoArray, register_kernel};
7use vortex_error::VortexResult;
8
9use crate::{SparseArray, SparseVTable};
10
11impl TakeKernel for SparseVTable {
12 fn take(&self, array: &SparseArray, take_indices: &dyn Array) -> VortexResult<ArrayRef> {
13 let patches_take = if array.fill_scalar().is_null() {
14 array.patches().take(take_indices)?
15 } else {
16 array.patches().take_with_nulls(take_indices)?
17 };
18
19 let Some(new_patches) = patches_take else {
20 let result_fill_scalar = array.fill_scalar().cast(
21 &array
22 .dtype()
23 .union_nullability(take_indices.dtype().nullability()),
24 )?;
25 return Ok(ConstantArray::new(result_fill_scalar, take_indices.len()).into_array());
26 };
27
28 if new_patches.array_len() == new_patches.values().len() {
30 return Ok(new_patches.into_values());
31 }
32
33 Ok(SparseArray::try_new_from_patches(
34 new_patches,
35 array.fill_scalar().cast(
36 &array
37 .dtype()
38 .union_nullability(take_indices.dtype().nullability()),
39 )?,
40 )?
41 .into_array())
42 }
43}
44
// Registers the `TakeKernel` impl above for `SparseVTable` (wrapped in `TakeKernelAdapter`)
// so the generic `compute::take` can dispatch to it; the tests below reach it only via `take`.
register_kernel!(TakeKernelAdapter(SparseVTable).lift());
46
47#[cfg(test)]
48mod test {
49 use vortex_array::arrays::PrimitiveArray;
50 use vortex_array::compute::take;
51 use vortex_array::validity::Validity;
52 use vortex_array::{Array, ArrayRef, IntoArray, ToCanonical};
53 use vortex_buffer::buffer;
54 use vortex_dtype::PType::I32;
55 use vortex_dtype::{DType, Nullability};
56 use vortex_scalar::Scalar;
57
58 use crate::{SparseArray, SparseVTable};
59
60 fn test_array_fill_value() -> Scalar {
61 Scalar::null_typed::<f64>()
63 }
64
65 fn sparse_array() -> ArrayRef {
66 SparseArray::try_new(
67 buffer![0u64, 37, 47, 99].into_array(),
68 PrimitiveArray::new(buffer![1.23f64, 0.47, 9.99, 3.5], Validity::AllValid).into_array(),
69 100,
70 test_array_fill_value(),
71 )
72 .unwrap()
73 .into_array()
74 }
75
76 #[test]
77 fn take_with_non_zero_offset() {
78 let sparse = sparse_array();
79 let sparse = sparse.slice(30, 40).unwrap();
80 let sparse = take(&sparse, &buffer![6, 7, 8].into_array()).unwrap();
81 assert_eq!(sparse.scalar_at(0).unwrap(), test_array_fill_value());
82 assert_eq!(sparse.scalar_at(1).unwrap(), Scalar::from(Some(0.47)));
83 assert_eq!(sparse.scalar_at(2).unwrap(), test_array_fill_value());
84 }
85
86 #[test]
87 fn sparse_take() {
88 let sparse = sparse_array();
89 let prim = take(&sparse, &buffer![0, 47, 47, 0, 99].into_array())
90 .unwrap()
91 .to_primitive()
92 .unwrap();
93 assert_eq!(prim.as_slice::<f64>(), [1.23f64, 9.99, 9.99, 1.23, 3.5]);
94 }
95
96 #[test]
97 fn nonexistent_take() {
98 let sparse = sparse_array();
99 let taken = take(&sparse, &buffer![69].into_array()).unwrap();
100 assert_eq!(taken.len(), 1);
101 assert_eq!(taken.scalar_at(0).unwrap(), test_array_fill_value());
102 }
103
104 #[test]
105 fn ordered_take() {
106 let sparse = sparse_array();
107 let taken_arr = take(&sparse, &buffer![69, 37].into_array()).unwrap();
108 let taken = taken_arr.as_::<SparseVTable>();
109
110 assert_eq!(
111 taken
112 .patches()
113 .indices()
114 .to_primitive()
115 .unwrap()
116 .as_slice::<u64>(),
117 [1]
118 );
119 assert_eq!(
120 taken
121 .patches()
122 .values()
123 .to_primitive()
124 .unwrap()
125 .as_slice::<f64>(),
126 [0.47f64]
127 );
128 assert_eq!(taken.len(), 2);
129 }
130
131 #[test]
132 fn nullable_take() {
133 let arr = SparseArray::try_new(
134 buffer![1u32].into_array(),
135 buffer![10].into_array(),
136 10,
137 Scalar::primitive(1, Nullability::NonNullable),
138 )
139 .unwrap();
140
141 let taken = take(
142 arr.as_ref(),
143 PrimitiveArray::from_option_iter([Some(2u32), Some(1u32), Option::<u32>::None])
144 .as_ref(),
145 )
146 .unwrap();
147
148 assert_eq!(
149 taken.scalar_at(0).unwrap(),
150 Scalar::primitive(1, Nullability::Nullable)
151 );
152 assert_eq!(
153 taken.scalar_at(1).unwrap(),
154 Scalar::primitive(10, Nullability::Nullable)
155 );
156 assert_eq!(
157 taken.scalar_at(2).unwrap(),
158 Scalar::null(DType::Primitive(I32, Nullability::Nullable))
159 );
160 }
161
162 #[test]
163 fn nullable_take_with_many_patches() {
164 let arr = SparseArray::try_new(
165 buffer![1u32, 3, 7, 8, 9].into_array(),
166 buffer![10, 8, 3, 2, 1].into_array(),
167 10,
168 Scalar::primitive(1, Nullability::NonNullable),
169 )
170 .unwrap();
171
172 let taken = take(
173 arr.as_ref(),
174 PrimitiveArray::from_option_iter([Some(2u32), Some(1u32), Option::<u32>::None])
175 .as_ref(),
176 )
177 .unwrap();
178
179 assert_eq!(
180 taken.scalar_at(0).unwrap(),
181 Scalar::primitive(1, Nullability::Nullable)
182 );
183 assert_eq!(
184 taken.scalar_at(1).unwrap(),
185 Scalar::primitive(10, Nullability::Nullable)
186 );
187 assert_eq!(
188 taken.scalar_at(2).unwrap(),
189 Scalar::null(DType::Primitive(I32, Nullability::Nullable))
190 );
191 }
192}