Skip to main content

vortex_array/arrays/patched/vtable/
mod.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use prost::Message;
5
6use crate::ArrayEq;
7use crate::ArrayHash;
8mod kernels;
9mod operations;
10mod slice;
11
12use std::hash::Hash;
13use std::hash::Hasher;
14
15use vortex_buffer::Buffer;
16use vortex_error::VortexExpect;
17use vortex_error::VortexResult;
18use vortex_error::vortex_panic;
19use vortex_session::VortexSession;
20use vortex_session::registry::CachedId;
21
22use crate::ArrayRef;
23use crate::Canonical;
24use crate::ExecutionCtx;
25use crate::ExecutionResult;
26use crate::IntoArray;
27use crate::Precision;
28use crate::array::Array;
29use crate::array::ArrayId;
30use crate::array::ArrayParts;
31use crate::array::ArrayView;
32use crate::array::VTable;
33use crate::array::ValidityChild;
34use crate::array::ValidityVTableFromChild;
35use crate::arrays::Primitive;
36use crate::arrays::PrimitiveArray;
37use crate::arrays::patched::PatchedArrayExt;
38use crate::arrays::patched::PatchedArraySlotsExt;
39use crate::arrays::patched::PatchedData;
40use crate::arrays::patched::PatchedSlots;
41use crate::arrays::patched::PatchedSlotsView;
42use crate::arrays::patched::compute::rules::PARENT_RULES;
43use crate::arrays::patched::vtable::kernels::PARENT_KERNELS;
44use crate::arrays::primitive::PrimitiveDataParts;
45use crate::buffer::BufferHandle;
46use crate::builders::ArrayBuilder;
47use crate::builders::PrimitiveBuilder;
48use crate::dtype::DType;
49use crate::dtype::NativePType;
50use crate::dtype::PType;
51use crate::match_each_native_ptype;
52use crate::require_child;
53use crate::serde::ArrayChildren;
54
/// A [`Patched`]-encoded Vortex array.
///
/// This is the generic [`Array`] wrapper specialized to the [`Patched`] encoding.
pub type PatchedArray = Array<Patched>;
57
/// Marker type for the patched encoding, registered under the id `"vortex.patched"`.
///
/// The type itself holds no fields. Its [`VTable`] implementation exposes four children:
/// the inner array, the lane offsets, the patch indices, and the patch values.
#[derive(Clone, Debug)]
pub struct Patched;
60
/// Validity of a patched array is taken from its `inner` child.
impl ValidityChild<Patched> for Patched {
    /// Returns a clone of the `inner` child, which is the child that carries validity.
    fn validity_child(array: ArrayView<'_, Patched>) -> ArrayRef {
        array.inner().clone()
    }
}
66
/// Protobuf-encoded metadata for a serialized [`Patched`] array.
///
/// Produced by the `serialize` method of the [`VTable`] implementation and decoded again by
/// `deserialize`, which uses these fields to size the child arrays and to rebuild
/// [`PatchedData`].
#[derive(Clone, prost::Message)]
pub struct PatchedMetadata {
    /// The total number of patches, and the length of the indices and values child arrays.
    #[prost(uint32, tag = "1")]
    pub(crate) n_patches: u32,

    /// The number of lanes used for patch indexing. Must be a power of two between 1 and 128.
    #[prost(uint32, tag = "2")]
    pub(crate) n_lanes: u32,

    /// An offset into the first chunk's patches that should be considered in-view.
    ///
    /// Always between 0 and 1023.
    #[prost(uint32, tag = "3")]
    pub(crate) offset: u32,
}
83
84impl ArrayHash for PatchedData {
85    fn array_hash<H: Hasher>(&self, state: &mut H, _precision: Precision) {
86        self.offset.hash(state);
87        self.n_lanes.hash(state);
88    }
89}
90
91impl ArrayEq for PatchedData {
92    fn array_eq(&self, other: &Self, _precision: Precision) -> bool {
93        self.offset == other.offset && self.n_lanes == other.n_lanes
94    }
95}
96
97impl VTable for Patched {
98    type ArrayData = PatchedData;
99    type OperationsVTable = Self;
100    type ValidityVTable = ValidityVTableFromChild;
101
    /// Returns the encoding id `"vortex.patched"`, held in a function-local static.
    fn id(&self) -> ArrayId {
        static ID: CachedId = CachedId::new("vortex.patched");
        *ID
    }
106
107    fn validate(
108        &self,
109        data: &PatchedData,
110        dtype: &DType,
111        len: usize,
112        slots: &[Option<ArrayRef>],
113    ) -> VortexResult<()> {
114        data.validate(dtype, len, &PatchedSlotsView::from_slots(slots))
115    }
116
    /// A patched array reports zero buffers.
    fn nbuffers(_array: ArrayView<'_, Self>) -> usize {
        0
    }
120
    /// Always panics: the array reports zero buffers, so every `idx` is invalid.
    fn buffer(_array: ArrayView<'_, Self>, idx: usize) -> BufferHandle {
        vortex_panic!("invalid buffer index for PatchedArray: {idx}");
    }
124
    /// Always panics: the array reports zero buffers, so there is no buffer to name.
    fn buffer_name(_array: ArrayView<'_, Self>, idx: usize) -> Option<String> {
        vortex_panic!("invalid buffer index for PatchedArray: {idx}");
    }
128
129    fn child(array: ArrayView<'_, Self>, idx: usize) -> ArrayRef {
130        match idx {
131            PatchedSlots::INNER => array.inner().clone(),
132            PatchedSlots::LANE_OFFSETS => array.lane_offsets().clone(),
133            PatchedSlots::PATCH_INDICES => array.patch_indices().clone(),
134            PatchedSlots::PATCH_VALUES => array.patch_values().clone(),
135            _ => vortex_panic!("invalid child index for PatchedArray: {idx}"),
136        }
137    }
138
139    fn serialize(
140        array: ArrayView<'_, Self>,
141        _session: &VortexSession,
142    ) -> VortexResult<Option<Vec<u8>>> {
143        Ok(Some(
144            PatchedMetadata {
145                n_patches: u32::try_from(array.patch_indices().len())?,
146                n_lanes: u32::try_from(array.n_lanes())?,
147                offset: u32::try_from(array.offset())?,
148            }
149            .encode_to_vec(),
150        ))
151    }
152
153    fn deserialize(
154        &self,
155        dtype: &DType,
156        len: usize,
157        metadata: &[u8],
158        _buffers: &[BufferHandle],
159        children: &dyn ArrayChildren,
160        _session: &VortexSession,
161    ) -> VortexResult<ArrayParts<Self>> {
162        let metadata = PatchedMetadata::decode(metadata)?;
163        let n_patches = metadata.n_patches as usize;
164        let n_lanes = metadata.n_lanes as usize;
165        let offset = metadata.offset as usize;
166
167        // n_chunks should correspond to the chunk in the `inner`.
168        // After slicing when offset > 0, there may be additional chunks.
169        let n_chunks = (len + offset).div_ceil(1024);
170
171        let inner = children.get(0, dtype, len)?;
172        let lane_offsets = children.get(1, PType::U32.into(), n_chunks * n_lanes + 1)?;
173        let indices = children.get(2, PType::U16.into(), n_patches)?;
174        let values = children.get(3, dtype, n_patches)?;
175
176        let data = PatchedData { n_lanes, offset };
177        let slots = PatchedSlots {
178            inner,
179            lane_offsets,
180            patch_indices: indices,
181            patch_values: values,
182        }
183        .into_slots();
184        Ok(ArrayParts::new(self.clone(), dtype.clone(), len, data).with_slots(slots))
185    }
186
187    fn append_to_builder(
188        array: ArrayView<'_, Self>,
189        builder: &mut dyn ArrayBuilder,
190        ctx: &mut ExecutionCtx,
191    ) -> VortexResult<()> {
192        let dtype = array.array().dtype();
193
194        if !dtype.is_primitive() {
195            // Default pathway: canonicalize and propagate.
196            let canonical = array
197                .array()
198                .clone()
199                .execute::<Canonical>(ctx)?
200                .into_array();
201            builder.extend_from_array(&canonical);
202            return Ok(());
203        }
204
205        let ptype = dtype.as_ptype();
206
207        let len = array.len();
208
209        array.inner().append_to_builder(builder, ctx)?;
210
211        let offset = array.offset();
212        let lane_offsets = array
213            .lane_offsets()
214            .clone()
215            .execute::<PrimitiveArray>(ctx)?;
216        let indices = array
217            .patch_indices()
218            .clone()
219            .execute::<PrimitiveArray>(ctx)?;
220        let values = array
221            .patch_values()
222            .clone()
223            .execute::<PrimitiveArray>(ctx)?;
224
225        match_each_native_ptype!(ptype, |V| {
226            let typed_builder = builder
227                .as_any_mut()
228                .downcast_mut::<PrimitiveBuilder<V>>()
229                .vortex_expect("correctly typed builder");
230
231            // Overwrite the last `len` elements of the builder. These would have been
232            // populated by the inner.append_to_builder() call above.
233            let output = typed_builder.values_mut();
234            let trailer = output.len() - len;
235
236            apply_patches_primitive::<V>(
237                &mut output[trailer..],
238                offset,
239                len,
240                array.n_lanes(),
241                lane_offsets.as_slice::<u32>(),
242                indices.as_slice::<u16>(),
243                values.as_slice::<V>(),
244            );
245        });
246
247        Ok(())
248    }
249
    /// Returns the name of slot `idx`, looked up in `PatchedSlots::NAMES`.
    ///
    /// Indexing is unchecked here, so an `idx` outside `NAMES` panics.
    fn slot_name(_array: ArrayView<'_, Self>, idx: usize) -> String {
        PatchedSlots::NAMES[idx].to_string()
    }
253
    /// Executes the patched array into a primitive array with all in-view patches applied.
    ///
    /// The inner array's buffer is used as the output, the patches are written over it, and
    /// the inner array's validity is attached to the result unchanged.
    fn execute(array: Array<Self>, _ctx: &mut ExecutionCtx) -> VortexResult<ExecutionResult> {
        // NOTE(review): `require_child!` presumably hands control back to the executor until
        // each child is `Primitive` — confirm against the macro definition.
        let array = require_child!(array, array.inner(), PatchedSlots::INNER => Primitive);
        let array =
            require_child!(array, array.lane_offsets(), PatchedSlots::LANE_OFFSETS => Primitive);
        let array =
            require_child!(array, array.patch_indices(), PatchedSlots::PATCH_INDICES => Primitive);
        let array =
            require_child!(array, array.patch_values(), PatchedSlots::PATCH_VALUES => Primitive);

        let len = array.len();

        // Read the scalar fields before `array` is consumed below.
        let n_lanes = array.n_lanes;
        let offset = array.offset;
        // Take the slots by value when `try_into_parts` succeeds; otherwise copy them out of
        // the array that is handed back.
        let slots = match array.try_into_parts() {
            Ok(parts) => PatchedSlots::from_slots(parts.slots),
            Err(array) => PatchedSlotsView::from_slots(array.slots()).to_owned(),
        };

        // TODO(joe): use iterative execution
        let PrimitiveDataParts {
            buffer,
            ptype,
            validity,
        } = slots.inner.downcast::<Primitive>().into_data_parts();

        let values = slots.patch_values.downcast::<Primitive>();
        let lane_offsets = slots.lane_offsets.downcast::<Primitive>();
        let patch_indices = slots.patch_indices.downcast::<Primitive>();

        // Dispatch on the patch values' ptype; the result is rebuilt with the inner's `ptype`.
        let patched_values = match_each_native_ptype!(values.ptype(), |V| {
            // Obtain a mutable, typed buffer over the inner array's bytes.
            let mut output = Buffer::<V>::from_byte_buffer(buffer.unwrap_host()).into_mut();

            apply_patches_primitive::<V>(
                &mut output,
                offset,
                len,
                n_lanes,
                lane_offsets.as_slice::<u32>(),
                patch_indices.as_slice::<u16>(),
                values.as_slice::<V>(),
            );

            let output = output.freeze();

            PrimitiveArray::from_byte_buffer(output.into_byte_buffer(), ptype, validity)
        });

        Ok(ExecutionResult::done(patched_values.into_array()))
    }
303
    /// Delegates parent execution to the `PARENT_KERNELS` set from the `kernels` submodule.
    fn execute_parent(
        array: ArrayView<'_, Self>,
        parent: &ArrayRef,
        child_idx: usize,
        ctx: &mut ExecutionCtx,
    ) -> VortexResult<Option<ArrayRef>> {
        PARENT_KERNELS.execute(array, parent, child_idx, ctx)
    }
312
    /// Delegates parent reduction to the `PARENT_RULES` set from the patched compute rules.
    fn reduce_parent(
        array: ArrayView<'_, Self>,
        parent: &ArrayRef,
        child_idx: usize,
    ) -> VortexResult<Option<ArrayRef>> {
        PARENT_RULES.evaluate(array, parent, child_idx)
    }
320}
321
322/// Apply patches on top of the existing value types.
323fn apply_patches_primitive<V: NativePType>(
324    output: &mut [V],
325    offset: usize,
326    len: usize,
327    n_lanes: usize,
328    lane_offsets: &[u32],
329    indices: &[u16],
330    values: &[V],
331) {
332    let n_chunks = (offset + len).div_ceil(1024);
333    for chunk in 0..n_chunks {
334        let start = lane_offsets[chunk * n_lanes] as usize;
335        let stop = lane_offsets[chunk * n_lanes + n_lanes] as usize;
336
337        for idx in start..stop {
338            // the indices slice is measured as an offset into the 1024-value chunk.
339            let index = chunk * 1024 + indices[idx] as usize;
340            if index < offset || index >= offset + len {
341                continue;
342            }
343
344            let value = values[idx];
345            output[index - offset] = value;
346        }
347    }
348}
349
350#[cfg(test)]
351mod tests {
352    use rstest::rstest;
353    use vortex_buffer::ByteBufferMut;
354    use vortex_buffer::buffer;
355    use vortex_buffer::buffer_mut;
356    use vortex_error::VortexResult;
357    use vortex_session::VortexSession;
358    use vortex_session::registry::ReadContext;
359
360    use crate::ArrayContext;
361    use crate::Canonical;
362    use crate::ExecutionCtx;
363    use crate::IntoArray;
364    use crate::LEGACY_SESSION;
365    use crate::arrays::Patched;
366    use crate::arrays::PatchedArray;
367    use crate::arrays::PrimitiveArray;
368    use crate::arrays::patched::PatchedArraySlotsExt;
369    use crate::arrays::patched::PatchedSlots;
370    use crate::arrays::patched::PatchedSlotsView;
371    use crate::assert_arrays_eq;
372    use crate::builders::builder_with_capacity;
373    use crate::patches::Patches;
374    use crate::serde::SerializeOptions;
375    use crate::serde::SerializedArray;
376    use crate::validity::Validity;
377
378    #[test]
379    fn test_execute() {
380        let values = buffer![0u16; 1024].into_array();
381        let patches = Patches::new(
382            1024,
383            0,
384            buffer![1u32, 2, 3].into_array(),
385            buffer![1u16; 3].into_array(),
386            None,
387        )
388        .unwrap();
389
390        let session = VortexSession::empty();
391        let mut ctx = ExecutionCtx::new(session);
392
393        let array = Patched::from_array_and_patches(values, &patches, &mut ctx)
394            .unwrap()
395            .into_array();
396
397        let executed = array
398            .execute::<Canonical>(&mut ctx)
399            .unwrap()
400            .into_primitive()
401            .into_buffer::<u16>();
402
403        let mut expected = buffer_mut![0u16; 1024];
404        expected[1] = 1;
405        expected[2] = 1;
406        expected[3] = 1;
407
408        assert_eq!(executed, expected.freeze());
409    }
410
411    #[test]
412    fn test_execute_sliced() {
413        let values = buffer![0u16; 1024].into_array();
414        let patches = Patches::new(
415            1024,
416            0,
417            buffer![1u32, 2, 3].into_array(),
418            buffer![1u16; 3].into_array(),
419            None,
420        )
421        .unwrap();
422
423        let session = VortexSession::empty();
424        let mut ctx = ExecutionCtx::new(session);
425
426        let array = Patched::from_array_and_patches(values, &patches, &mut ctx)
427            .unwrap()
428            .into_array()
429            .slice(3..1024)
430            .unwrap();
431
432        let executed = array
433            .execute::<Canonical>(&mut ctx)
434            .unwrap()
435            .into_primitive()
436            .into_buffer::<u16>();
437
438        let mut expected = buffer_mut![0u16; 1021];
439        expected[0] = 1;
440
441        assert_eq!(executed, expected.freeze());
442    }
443
444    #[test]
445    fn test_append_to_builder_non_nullable() {
446        let values = PrimitiveArray::new(buffer![0u16; 1024], Validity::NonNullable).into_array();
447        let patches = Patches::new(
448            1024,
449            0,
450            buffer![1u32, 2, 3].into_array(),
451            buffer![10u16, 20, 30].into_array(),
452            None,
453        )
454        .unwrap();
455
456        let session = VortexSession::empty();
457        let mut ctx = ExecutionCtx::new(session);
458
459        let array = Patched::from_array_and_patches(values, &patches, &mut ctx)
460            .unwrap()
461            .into_array();
462
463        let mut builder = builder_with_capacity(array.dtype(), array.len());
464        array.append_to_builder(builder.as_mut(), &mut ctx).unwrap();
465
466        let result = builder.finish();
467
468        let mut expected = buffer_mut![0u16; 1024];
469        expected[1] = 10;
470        expected[2] = 20;
471        expected[3] = 30;
472        let expected = expected.into_array();
473
474        assert_arrays_eq!(expected, result);
475    }
476
477    #[test]
478    fn test_append_to_builder_sliced() {
479        let values = PrimitiveArray::new(buffer![0u16; 1024], Validity::NonNullable).into_array();
480        let patches = Patches::new(
481            1024,
482            0,
483            buffer![1u32, 2, 3].into_array(),
484            buffer![10u16, 20, 30].into_array(),
485            None,
486        )
487        .unwrap();
488
489        let session = VortexSession::empty();
490        let mut ctx = ExecutionCtx::new(session);
491
492        let array = Patched::from_array_and_patches(values, &patches, &mut ctx)
493            .unwrap()
494            .into_array()
495            .slice(3..1024)
496            .unwrap();
497
498        let mut builder = builder_with_capacity(array.dtype(), array.len());
499        array.append_to_builder(builder.as_mut(), &mut ctx).unwrap();
500
501        let result = builder.finish();
502
503        let mut expected = buffer_mut![0u16; 1021];
504        expected[0] = 30;
505        let expected = expected.into_array();
506
507        assert_arrays_eq!(expected, result);
508    }
509
510    #[test]
511    fn test_append_to_builder_with_validity() {
512        // Create inner array with nulls at indices 0 and 5.
513        let validity = Validity::from_iter((0..10).map(|i| i != 0 && i != 5));
514        let values = PrimitiveArray::new(buffer![0u16; 10], validity).into_array();
515
516        // Apply patches at indices 1, 2, 3.
517        let patches = Patches::new(
518            10,
519            0,
520            buffer![1u32, 2, 3].into_array(),
521            buffer![10u16, 20, 30].into_array(),
522            None,
523        )
524        .unwrap();
525
526        let session = VortexSession::empty();
527        let mut ctx = ExecutionCtx::new(session);
528
529        let array = Patched::from_array_and_patches(values, &patches, &mut ctx)
530            .unwrap()
531            .into_array();
532
533        let mut builder = builder_with_capacity(array.dtype(), array.len());
534        array.append_to_builder(builder.as_mut(), &mut ctx).unwrap();
535
536        let result = builder.finish();
537
538        // Expected: null at 0, patched 10/20/30 at 1/2/3, zero at 4, null at 5, zeros at 6-9.
539        let expected = PrimitiveArray::from_option_iter([
540            None,
541            Some(10u16),
542            Some(20),
543            Some(30),
544            Some(0),
545            None,
546            Some(0),
547            Some(0),
548            Some(0),
549            Some(0),
550        ])
551        .into_array();
552
553        assert_arrays_eq!(expected, result);
554    }
555
556    fn make_patched_array(
557        inner: impl IntoIterator<Item = u16>,
558        patch_indices: &[u32],
559        patch_values: &[u16],
560    ) -> VortexResult<PatchedArray> {
561        let values: Vec<u16> = inner.into_iter().collect();
562        let len = values.len();
563        let array = PrimitiveArray::from_iter(values).into_array();
564
565        let indices = PrimitiveArray::from_iter(patch_indices.iter().copied()).into_array();
566        let patch_vals = PrimitiveArray::from_iter(patch_values.iter().copied()).into_array();
567
568        let patches = Patches::new(len, 0, indices, patch_vals, None)?;
569
570        let session = VortexSession::empty();
571        let mut ctx = ExecutionCtx::new(session);
572
573        Patched::from_array_and_patches(array, &patches, &mut ctx)
574    }
575
576    #[rstest]
577    #[case::basic(
578        make_patched_array(vec![0u16; 1024], &[1, 2, 3], &[10, 20, 30]).unwrap().into_array()
579    )]
580    #[case::multi_chunk(
581        make_patched_array(vec![0u16; 4096], &[100, 1500, 2500, 3500], &[11, 22, 33, 44]).unwrap().into_array()
582    )]
583    #[case::sliced({
584        let arr = make_patched_array(vec![0u16; 1024], &[1, 2, 3], &[10, 20, 30]).unwrap();
585        arr.into_array().slice(2..1024).unwrap()
586    })]
587    fn test_serde_roundtrip(#[case] array: crate::ArrayRef) {
588        let dtype = array.dtype().clone();
589        let len = array.len();
590
591        let ctx = ArrayContext::empty();
592        let serialized = array
593            .serialize(&ctx, &LEGACY_SESSION, &SerializeOptions::default())
594            .unwrap();
595
596        // Concat into a single buffer.
597        let mut concat = ByteBufferMut::empty();
598        for buf in serialized {
599            concat.extend_from_slice(buf.as_ref());
600        }
601        let concat = concat.freeze();
602
603        let parts = SerializedArray::try_from(concat).unwrap();
604        let decoded = parts
605            .decode(
606                &dtype,
607                len,
608                &ReadContext::new(ctx.to_ids()),
609                &LEGACY_SESSION,
610            )
611            .unwrap();
612
613        assert!(decoded.is::<Patched>());
614        assert_eq!(
615            array.display_values().to_string(),
616            decoded.display_values().to_string()
617        );
618    }
619
620    #[test]
621    fn test_with_slots_basic() -> VortexResult<()> {
622        let array = make_patched_array(vec![0u16; 1024], &[1, 2, 3], &[10, 20, 30])?;
623
624        // Get original children via accessor methods
625        let slots = PatchedSlots::from_slots(array.as_array().slots().to_vec());
626        let view = PatchedSlotsView::from_slots(array.as_array().slots());
627        assert_eq!(view.inner.len(), array.inner().len());
628
629        // Create new PatchedArray with same children using with_slots
630        let array_ref = array.into_array();
631        let new_array = array_ref.clone().with_slots(slots.into_slots())?;
632
633        assert!(new_array.is::<Patched>());
634        assert_eq!(array_ref.len(), new_array.len());
635        assert_eq!(array_ref.dtype(), new_array.dtype());
636
637        // Execute both and compare results
638        let mut ctx = ExecutionCtx::new(VortexSession::empty());
639        let original_executed = array_ref.execute::<Canonical>(&mut ctx)?.into_primitive();
640        let new_executed = new_array.execute::<Canonical>(&mut ctx)?.into_primitive();
641
642        assert_arrays_eq!(original_executed, new_executed);
643
644        Ok(())
645    }
646
647    #[test]
648    fn test_with_slots_modified_inner() -> VortexResult<()> {
649        let array = make_patched_array(vec![0u16; 10], &[1, 2, 3], &[10, 20, 30])?;
650
651        // Create a different inner array (all 5s instead of 0s)
652        let new_inner = PrimitiveArray::from_iter(vec![5u16; 10]).into_array();
653        let slots = PatchedSlots {
654            inner: new_inner,
655            lane_offsets: array.lane_offsets().clone(),
656            patch_indices: array.patch_indices().clone(),
657            patch_values: array.patch_values().clone(),
658        };
659
660        let array_ref = array.into_array();
661        let new_array = array_ref.with_slots(slots.into_slots())?;
662
663        // Execute and verify the inner values changed (except at patch positions)
664        let mut ctx = ExecutionCtx::new(VortexSession::empty());
665        let executed = new_array.execute::<Canonical>(&mut ctx)?.into_primitive();
666
667        // Expected: all 5s except indices 1, 2, 3 which are patched to 10, 20, 30
668        let expected = PrimitiveArray::from_iter([5u16, 10, 20, 30, 5, 5, 5, 5, 5, 5]);
669        assert_arrays_eq!(expected, executed);
670
671        Ok(())
672    }
673}