// vortex_array/arrays/patched/array.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::Display;
5use std::fmt::Formatter;
6use std::ops::Range;
7
8use vortex_buffer::Buffer;
9use vortex_buffer::BufferMut;
10use vortex_error::VortexResult;
11use vortex_error::vortex_ensure;
12use vortex_error::vortex_err;
13
14use crate::ArrayRef;
15use crate::Canonical;
16use crate::ExecutionCtx;
17use crate::IntoArray;
18use crate::array::Array;
19use crate::array::ArrayParts;
20use crate::array::TypedArrayRef;
21use crate::array_slots;
22use crate::arrays::Patched;
23use crate::arrays::PrimitiveArray;
24use crate::arrays::patched::TransposedPatches;
25use crate::arrays::patched::patch_lanes;
26use crate::buffer::BufferHandle;
27use crate::dtype::DType;
28use crate::dtype::IntegerPType;
29use crate::dtype::NativePType;
30use crate::dtype::PType;
31use crate::match_each_native_ptype;
32use crate::match_each_unsigned_integer_ptype;
33use crate::patches::Patches;
34use crate::validity::Validity;
35
/// Metadata for a [`Patched`] array: a base array overlaid with sparse patch
/// (exception) values stored in a chunk/lane-transposed layout.
///
/// The slot arrays themselves live in [`PatchedSlots`]; this struct carries
/// only the scalar layout parameters.
#[derive(Debug, Clone)]
pub struct PatchedData {
    /// Number of lanes the patch indices and values have been split into. Each of the `n_chunks`
    /// of 1024 values is split into `n_lanes` lanes horizontally, each lane having 1024 / n_lanes
    /// values that might be patched.
    pub(super) n_lanes: usize,

    /// The offset into that first chunk that is considered in bounds.
    ///
    /// The patch indices of the first chunk less than `offset` should be skipped, and the offset
    /// should be subtracted out of the remaining offsets to get their final position in the
    /// executed array.
    pub(super) offset: usize,
}
50
/// The child-array slots of a [`Patched`] array.
///
/// The `#[array_slots]` attribute generates the `into_slots`/`from_slots`
/// plumbing plus a borrowed `PatchedSlotsView`.
#[array_slots(Patched)]
pub struct PatchedSlots {
    /// The inner array containing the base unpatched values.
    pub inner: ArrayRef,
    /// The lane offsets array for locating patches within lanes: one entry per
    /// chunk/lane slot plus a trailing end bound (see `transpose`).
    pub lane_offsets: ArrayRef,
    /// The indices of patched (exception) values.
    pub patch_indices: ArrayRef,
    /// The patched (exception) values at the corresponding indices.
    pub patch_values: ArrayRef,
}
62
63impl Display for PatchedData {
64    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
65        write!(f, "n_lanes: {}, offset: {}", self.n_lanes, self.offset)
66    }
67}
68
69impl PatchedData {
70    pub(crate) fn validate(
71        &self,
72        dtype: &DType,
73        len: usize,
74        slots: &PatchedSlotsView,
75    ) -> VortexResult<()> {
76        vortex_ensure!(
77            slots.inner.dtype() == dtype,
78            "PatchedArray base dtype {} does not match outer dtype {}",
79            slots.inner.dtype(),
80            dtype
81        );
82        vortex_ensure!(
83            slots.inner.len() == len,
84            "PatchedArray base len {} does not match outer len {}",
85            slots.inner.len(),
86            len
87        );
88        vortex_ensure!(
89            slots.patch_indices.len() == slots.patch_values.len(),
90            "PatchedArray patch indices len {} does not match patch values len {}",
91            slots.patch_indices.len(),
92            slots.patch_values.len()
93        );
94        Ok(())
95    }
96}
97
98pub trait PatchedArrayExt: PatchedArraySlotsExt {
99    #[inline]
100    fn n_lanes(&self) -> usize {
101        self.n_lanes
102    }
103
104    #[inline]
105    fn offset(&self) -> usize {
106        self.offset
107    }
108
109    #[inline]
110    fn lane_range(&self, chunk: usize, lane: usize) -> VortexResult<Range<usize>> {
111        assert!(chunk * 1024 <= self.as_ref().len() + self.offset());
112        assert!(lane < self.n_lanes());
113
114        let start = self
115            .lane_offsets()
116            .scalar_at(chunk * self.n_lanes() + lane)?;
117        let stop = self
118            .lane_offsets()
119            .scalar_at(chunk * self.n_lanes() + lane + 1)?;
120
121        let start = start
122            .as_primitive()
123            .as_::<usize>()
124            .ok_or_else(|| vortex_err!("could not cast lane_offset to usize"))?;
125
126        let stop = stop
127            .as_primitive()
128            .as_::<usize>()
129            .ok_or_else(|| vortex_err!("could not cast lane_offset to usize"))?;
130
131        Ok(start..stop)
132    }
133
134    fn slice_chunks(&self, chunks: Range<usize>) -> VortexResult<Array<Patched>> {
135        let lane_offsets_start = chunks.start * self.n_lanes();
136        let lane_offsets_stop = chunks.end * self.n_lanes() + 1;
137
138        let sliced_lane_offsets = self
139            .lane_offsets()
140            .slice(lane_offsets_start..lane_offsets_stop)?;
141        let indices = self.patch_indices().clone();
142        let values = self.patch_values().clone();
143
144        let begin = (chunks.start * 1024).saturating_sub(self.offset());
145        let end = (chunks.end * 1024)
146            .saturating_sub(self.offset())
147            .min(self.as_ref().len());
148
149        let offset = if chunks.start == 0 { self.offset() } else { 0 };
150        let inner = self.inner().slice(begin..end)?;
151        let len = inner.len();
152        let dtype = self.as_ref().dtype().clone();
153        let slots = PatchedSlots {
154            inner,
155            lane_offsets: sliced_lane_offsets,
156            patch_indices: indices,
157            patch_values: values,
158        }
159        .into_slots();
160
161        Ok(unsafe { Patched::new_unchecked(dtype, len, slots, self.n_lanes(), offset) })
162    }
163}
164
/// Blanket impl: every typed reference to a `Patched` array gets the extension methods.
impl<T: TypedArrayRef<Patched>> PatchedArrayExt for T {}
166
167impl Patched {
168    pub fn from_array_and_patches(
169        inner: ArrayRef,
170        patches: &Patches,
171        ctx: &mut ExecutionCtx,
172    ) -> VortexResult<Array<Patched>> {
173        vortex_ensure!(
174            inner.dtype().eq_with_nullability_superset(patches.dtype()),
175            "array DType must match patches DType"
176        );
177
178        vortex_ensure!(
179            inner.dtype().is_primitive(),
180            "Creating PatchedArray from Patches only supported for primitive arrays"
181        );
182
183        vortex_ensure!(
184            patches.num_patches() <= u32::MAX as usize,
185            "PatchedArray does not support > u32::MAX patch values"
186        );
187
188        vortex_ensure!(
189            patches.values().all_valid()?,
190            "PatchedArray cannot be built from Patches with nulls"
191        );
192
193        let values_ptype = patches.dtype().as_ptype();
194
195        let TransposedPatches {
196            n_lanes,
197            lane_offsets,
198            indices,
199            values,
200        } = transpose_patches(patches, ctx)?;
201
202        let lane_offsets = PrimitiveArray::from_buffer_handle(
203            BufferHandle::new_host(lane_offsets),
204            PType::U32,
205            Validity::NonNullable,
206        )
207        .into_array();
208        let indices = PrimitiveArray::from_buffer_handle(
209            BufferHandle::new_host(indices),
210            PType::U16,
211            Validity::NonNullable,
212        )
213        .into_array();
214        let values = PrimitiveArray::from_buffer_handle(
215            BufferHandle::new_host(values),
216            values_ptype,
217            Validity::NonNullable,
218        )
219        .into_array();
220
221        let dtype = inner.dtype().clone();
222        let len = inner.len();
223        let slots = PatchedSlots {
224            inner,
225            lane_offsets,
226            patch_indices: indices,
227            patch_values: values,
228        }
229        .into_slots();
230        Ok(unsafe { Self::new_unchecked(dtype, len, slots, n_lanes, 0) })
231    }
232
233    pub(crate) unsafe fn new_unchecked(
234        dtype: DType,
235        len: usize,
236        slots: Vec<Option<ArrayRef>>,
237        n_lanes: usize,
238        offset: usize,
239    ) -> Array<Patched> {
240        unsafe {
241            Array::from_parts_unchecked(
242                ArrayParts::new(Patched, dtype, len, PatchedData { n_lanes, offset })
243                    .with_slots(slots),
244            )
245        }
246    }
247}
248
/// Transpose a set of patches from the default sorted layout into the data parallel layout.
///
/// Canonicalizes the patch indices and values into primitive arrays, then
/// dispatches on their concrete ptypes so [`transpose`] can run over the raw
/// host buffers.
#[allow(clippy::cognitive_complexity)]
fn transpose_patches(patches: &Patches, ctx: &mut ExecutionCtx) -> VortexResult<TransposedPatches> {
    let array_len = patches.array_len();
    let offset = patches.offset();

    // Execute to canonical (primitive) form so the raw buffers can be read directly.
    let indices = patches
        .indices()
        .clone()
        .execute::<Canonical>(ctx)?
        .into_primitive();

    let values = patches
        .values()
        .clone()
        .execute::<Canonical>(ctx)?
        .into_primitive();

    let indices_ptype = indices.ptype();
    let values_ptype = values.ptype();

    // NOTE(review): `unwrap_host` presumably requires the buffers to live in
    // host memory after canonicalization — confirm against `execute`'s contract.
    let indices = indices.buffer_handle().clone().unwrap_host();
    let values = values.buffer_handle().clone().unwrap_host();

    // Monomorphize over the concrete index/value types and reinterpret the
    // untyped byte buffers at those element types.
    match_each_unsigned_integer_ptype!(indices_ptype, |I| {
        match_each_native_ptype!(values_ptype, |V| {
            let indices: Buffer<I> = Buffer::from_byte_buffer(indices);
            let values: Buffer<V> = Buffer::from_byte_buffer(values);

            Ok(transpose(
                indices.as_slice(),
                values.as_slice(),
                offset,
                array_len,
            ))
        })
    })
}
287
/// Scatter sorted patch `(index, value)` pairs into the chunk/lane-major
/// layout described on [`PatchedData`].
///
/// Implemented as a counting sort keyed on `(chunk, lane)`:
/// 1. count the patches landing in each chunk/lane slot,
/// 2. prefix-sum the counts into start offsets,
/// 3. scatter each pair to its final position, advancing the slot's cursor,
/// 4. re-scan the indices to rewind every cursor back to its start offset,
///    leaving `lane_offsets` as the final offset table (one trailing end bound).
#[allow(clippy::cast_possible_truncation)]
fn transpose<I: IntegerPType, V: NativePType>(
    indices_in: &[I],
    values_in: &[V],
    offset: usize,
    array_len: usize,
) -> TransposedPatches {
    // Total number of slots is number of chunks times number of lanes.
    let n_chunks = array_len.div_ceil(1024);
    // Chunk counts are stored in u32 lane offsets below.
    assert!(
        n_chunks <= u32::MAX as usize,
        "Cannot transpose patches for array with >= 4 trillion elements"
    );

    let n_lanes = patch_lanes::<V>();

    // We know upfront how many indices and values we'll have.
    let mut indices_buffer = BufferMut::with_capacity(indices_in.len());
    let mut values_buffer = BufferMut::with_capacity(values_in.len());

    // Number of patches in each chunk/lane.
    let mut lane_offsets: BufferMut<u32> = BufferMut::zeroed(n_chunks * n_lanes + 1);

    // Scan the index/value pairs once to get chunk/lane counts.
    // NOTE(review): assumes every index is >= `offset`; otherwise the
    // subtraction underflows — TODO confirm the `Patches` invariant upstream.
    for index in indices_in {
        let index = index.as_() - offset;
        let chunk = index / 1024;
        // Lanes interleave within a chunk: position p maps to lane p % n_lanes.
        let lane = index % n_lanes;

        // Count one slot ahead so the prefix sum below yields start offsets.
        lane_offsets[chunk * n_lanes + lane + 1] += 1;
    }

    // In-place prefix sum: lane_offsets[s] becomes slot s's start offset.
    for index in 1..lane_offsets.len() {
        lane_offsets[index] += lane_offsets[index - 1];
    }

    // Loop over patches, writing them to final positions.
    // Each slot's `lane_offsets` entry doubles as its write cursor here.
    let indices_out = indices_buffer.spare_capacity_mut();
    let values_out = values_buffer.spare_capacity_mut();
    for (index, &value) in std::iter::zip(indices_in, values_in) {
        let index = index.as_() - offset;
        let chunk = index / 1024;
        let lane = index % n_lanes;

        let position = &mut lane_offsets[chunk * n_lanes + lane];
        // Indices are stored chunk-relative, so they fit in u16 (< 1024).
        indices_out[*position as usize].write((index % 1024) as u16);
        values_out[*position as usize].write(value);
        *position += 1;
    }

    // SAFETY: the scatter above wrote every position in 0..len exactly once —
    // the per-slot cursor ranges partition the output buffers.
    unsafe {
        indices_buffer.set_len(indices_in.len());
        values_buffer.set_len(values_in.len());
    }

    // The cursors now sit at each slot's end; decrement once per patch to
    // rewind them back to the slot start offsets.
    for index in indices_in {
        let index = index.as_() - offset;
        let chunk = index / 1024;
        let lane = index % n_lanes;

        lane_offsets[chunk * n_lanes + lane] -= 1;
    }

    TransposedPatches {
        n_lanes,
        lane_offsets: lane_offsets.freeze().into_byte_buffer(),
        indices: indices_buffer.freeze().into_byte_buffer(),
        values: values_buffer.freeze().into_byte_buffer(),
    }
}
358
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;

    use super::PatchedSlots;
    use crate::ArrayRef;
    use crate::IntoArray;
    use crate::array_slots;
    use crate::arrays::Null;
    use crate::arrays::PrimitiveArray;
    use crate::validity::Validity;

    // Exercises the `#[array_slots]` derive with a mix of required and
    // optional slots; the macro generates `OptionalPatchedSlots{,View}`.
    #[array_slots(Null)]
    struct OptionalPatchedSlots {
        required: ArrayRef,
        maybe: Option<ArrayRef>,
    }

    // Slots survive a round trip through the generated `from_slots`
    // constructors, both as a borrowed view and as owned values.
    #[test]
    fn generated_slots_round_trip() {
        let required = PrimitiveArray::new(buffer![1u8, 2, 3], Validity::NonNullable).into_array();
        let optional = PrimitiveArray::new(buffer![4u8, 5, 6], Validity::NonNullable).into_array();

        // Borrowed view over the raw slot vector.
        let slot_vec = vec![Some(required.clone()), Some(optional.clone())];
        let view = OptionalPatchedSlotsView::from_slots(&slot_vec);
        assert_eq!(view.required.len(), 3);
        assert_eq!(view.maybe.expect("optional slot").len(), 3);

        // Owned reconstruction consumes the slot vector.
        let cloned = OptionalPatchedSlots::from_slots(slot_vec);
        assert_eq!(cloned.required.len(), required.len());
        assert_eq!(cloned.maybe.expect("optional clone").len(), optional.len());

        // The real `PatchedSlots` takes four required slots in declaration order.
        let rebuilt = PatchedSlots::from_slots(vec![
            Some(required.clone()),
            Some(optional.clone()),
            Some(required.clone()),
            Some(optional.clone()),
        ]);
        assert_eq!(rebuilt.inner.len(), required.len());
        assert_eq!(rebuilt.patch_values.len(), optional.len());
    }
}
400}