vortex_array/arrays/chunked/compute/
take.rs1use vortex_buffer::BufferMut;
5use vortex_dtype::{DType, PType};
6use vortex_error::VortexResult;
7
8use crate::arrays::chunked::ChunkedArray;
9use crate::arrays::{ChunkedVTable, PrimitiveArray};
10use crate::compute::{TakeKernel, TakeKernelAdapter, cast, take};
11use crate::validity::Validity;
12use crate::{Array, ArrayRef, IntoArray, ToCanonical, register_kernel};
13
14impl TakeKernel for ChunkedVTable {
15 fn take(&self, array: &ChunkedArray, indices: &dyn Array) -> VortexResult<ArrayRef> {
16 let indices = cast(
17 indices,
18 &DType::Primitive(PType::U64, indices.dtype().nullability()),
19 )?
20 .to_primitive()?;
21
22 let nullability = indices.dtype().nullability();
24 let indices_mask = indices.validity_mask()?;
25 let indices = indices.as_slice::<u64>();
26
27 let mut chunks = Vec::new();
28 let mut indices_in_chunk = BufferMut::<u64>::empty();
29 let mut start = 0;
30 let mut stop = 0;
31 let mut prev_chunk_idx = array.find_chunk_idx(indices[0].try_into()?).0;
33 for idx in indices {
34 let idx = usize::try_from(*idx)?;
35 let (chunk_idx, idx_in_chunk) = array.find_chunk_idx(idx);
36
37 if chunk_idx != prev_chunk_idx {
38 let indices_in_chunk_array = PrimitiveArray::new(
40 indices_in_chunk.clone().freeze(),
41 Validity::from_mask(indices_mask.slice(start, stop - start), nullability),
42 );
43 chunks.push(take(
44 array.chunk(prev_chunk_idx),
45 indices_in_chunk_array.as_ref(),
46 )?);
47 indices_in_chunk.clear();
48 start = stop;
49 }
50
51 indices_in_chunk.push(idx_in_chunk as u64);
52 stop += 1;
53 prev_chunk_idx = chunk_idx;
54 }
55
56 if !indices_in_chunk.is_empty() {
57 let indices_in_chunk_array = PrimitiveArray::new(
58 indices_in_chunk.freeze(),
59 Validity::from_mask(indices_mask.slice(start, stop - start), nullability),
60 );
61 chunks.push(take(
62 array.chunk(prev_chunk_idx),
63 indices_in_chunk_array.as_ref(),
64 )?);
65 }
66
67 unsafe {
69 Ok(ChunkedArray::new_unchecked(
70 chunks,
71 array.dtype().clone().union_nullability(nullability),
72 )
73 .into_array())
74 }
75 }
76}
77
// Registers the chunked `TakeKernel` above through `TakeKernelAdapter`. Presumably this lets the
// generic `take` compute function dispatch to it for chunked arrays; confirm against
// `register_kernel!`.
register_kernel!(TakeKernelAdapter(ChunkedVTable).lift());
79
#[cfg(test)]
mod test {
    use vortex_buffer::buffer;
    use vortex_dtype::{FieldNames, Nullability};

    use crate::IntoArray;
    use crate::array::Array;
    use crate::arrays::chunked::ChunkedArray;
    use crate::arrays::{BoolArray, PrimitiveArray, StructArray};
    use crate::canonical::ToCanonical;
    use crate::compute::conformance::take::test_take_conformance;
    use crate::compute::take;
    use crate::validity::Validity;

    /// Builds a chunked array of three `[1, 2, 3]` chunks (nine `i32` values in total).
    fn three_small_chunks() -> ChunkedArray {
        let chunk = buffer![1i32, 2, 3].into_array();
        let dtype = chunk.dtype().clone();
        let chunked =
            ChunkedArray::try_new(vec![chunk.clone(), chunk.clone(), chunk], dtype).unwrap();
        assert_eq!(chunked.nchunks(), 3);
        assert_eq!(chunked.len(), 9);
        chunked
    }

    #[test]
    fn test_take() {
        let chunked = three_small_chunks();
        // Indices 0, 0, 6, 4 resolve to chunks 0, 0, 2, 1.
        let positions = buffer![0u64, 0, 6, 4].into_array();

        let taken = take(chunked.as_ref(), positions.as_ref()).unwrap();
        let taken = taken.to_primitive().unwrap();
        assert_eq!(taken.as_slice::<i32>(), &[1, 1, 1, 2]);
    }

    #[test]
    fn test_take_nullability() {
        // A field-less, non-nullable struct of length 100, used as both chunks.
        let empty_struct =
            StructArray::try_new(FieldNames::default(), vec![], 100, Validity::NonNullable)
                .unwrap();
        let chunked =
            ChunkedArray::from_iter(vec![empty_struct.to_array(), empty_struct.to_array()]);

        // The null index must make the result nullable; 101 lands in the second chunk.
        let positions = PrimitiveArray::from_option_iter(vec![Some(0), None, Some(101)]);
        let taken = take(chunked.as_ref(), positions.as_ref()).unwrap();

        let expected = StructArray::try_new(
            FieldNames::default(),
            vec![],
            3,
            Validity::Array(BoolArray::from_iter(vec![true, false, true]).to_array()),
        )
        .unwrap();
        assert_eq!(taken.dtype(), expected.dtype());
        for i in 0..3 {
            assert_eq!(taken.scalar_at(i), expected.scalar_at(i));
        }
    }

    #[test]
    fn test_empty_take() {
        let chunked = three_small_chunks();

        let no_positions = PrimitiveArray::empty::<u64>(Nullability::NonNullable);
        let taken = take(chunked.as_ref(), no_positions.as_ref())
            .unwrap()
            .to_primitive()
            .unwrap();

        assert!(taken.is_empty());
        assert_eq!(taken.dtype(), chunked.dtype());
        assert!(taken.as_slice::<i32>().is_empty());
    }

    #[test]
    fn test_take_chunked_conformance() {
        // Two non-nullable chunks of different lengths.
        {
            let first = buffer![1i32, 2, 3].into_array();
            let second = buffer![4i32, 5].into_array();
            let dtype = first.dtype().clone();
            let chunked = ChunkedArray::try_new(vec![first, second], dtype).unwrap();
            test_take_conformance(chunked.as_ref());
        }

        // Nullable chunks, one of which contains a null.
        {
            let first = PrimitiveArray::from_option_iter([Some(1i32), None, Some(3)]);
            let second = PrimitiveArray::from_option_iter([Some(4i32), Some(5)]);
            let dtype = first.dtype().clone();
            let chunked =
                ChunkedArray::try_new(vec![first.into_array(), second.into_array()], dtype)
                    .unwrap();
            test_take_conformance(chunked.as_ref());
        }

        // Several identical chunks.
        {
            let chunk = buffer![10i32, 20, 30, 40, 50].into_array();
            let dtype = chunk.dtype().clone();
            let chunked =
                ChunkedArray::try_new(vec![chunk.clone(), chunk.clone(), chunk], dtype).unwrap();
            test_take_conformance(chunked.as_ref());
        }
    }
}