vortex_array/arrays/primitive/
mod.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::fmt::Debug;
5use std::iter;
6
7mod accessor;
8
9use arrow_buffer::BooleanBufferBuilder;
10use vortex_buffer::{Alignment, Buffer, BufferMut, ByteBuffer, ByteBufferMut};
11use vortex_dtype::{DType, NativePType, Nullability, PType, match_each_native_ptype};
12use vortex_error::{VortexExpect, VortexResult, vortex_err, vortex_panic};
13
14use crate::builders::ArrayBuilder;
15use crate::stats::{ArrayStats, StatsSetRef};
16use crate::validity::Validity;
17use crate::{ArrayRef, Canonical, EncodingId, EncodingRef, IntoArray, ToCanonical, vtable};
18
19mod compute;
20mod downcast;
21mod native_value;
22mod operator;
23mod ops;
24mod patch;
25mod serde;
26mod top_value;
27
28pub use compute::{IS_CONST_LANE_WIDTH, compute_is_constant};
29pub use native_value::NativeValue;
30
31use crate::vtable::{
32    ArrayVTable, CanonicalVTable, NotSupported, VTable, ValidityHelper,
33    ValidityVTableFromValidityHelper,
34};
35
36vtable!(Primitive);
37
38impl VTable for PrimitiveVTable {
39    type Array = PrimitiveArray;
40    type Encoding = PrimitiveEncoding;
41
42    type ArrayVTable = Self;
43    type CanonicalVTable = Self;
44    type OperationsVTable = Self;
45    type ValidityVTable = ValidityVTableFromValidityHelper;
46    type VisitorVTable = Self;
47    type ComputeVTable = NotSupported;
48    type EncodeVTable = NotSupported;
49    type PipelineVTable = Self;
50    type SerdeVTable = Self;
51
52    fn id(_encoding: &Self::Encoding) -> EncodingId {
53        EncodingId::new_ref("vortex.primitive")
54    }
55
56    fn encoding(_array: &Self::Array) -> EncodingRef {
57        EncodingRef::new_ref(PrimitiveEncoding.as_ref())
58    }
59}
60
61/// A primitive array that stores [native types][vortex_dtype::NativePType] in a contiguous buffer
62/// of memory, along with an optional validity child.
63///
64/// This mirrors the Apache Arrow Primitive layout and can be converted into and out of one
65/// without allocations or copies.
66///
67/// The underlying buffer must be natively aligned to the primitive type they are representing.
68///
69/// Values are stored in their native representation with proper alignment.
70/// Null values still occupy space in the buffer but are marked invalid in the validity mask.
71///
72/// # Examples
73///
74/// ```
75/// use vortex_array::arrays::PrimitiveArray;
76/// use vortex_array::compute::sum;
77/// ///
78/// // Create from iterator using FromIterator impl
79/// let array: PrimitiveArray = [1i32, 2, 3, 4, 5].into_iter().collect();
80///
81/// // Slice the array
82/// let sliced = array.slice(1..3);
83///
84/// // Access individual values
85/// let value = sliced.scalar_at(0);
86/// assert_eq!(value, 2i32.into());
87///
88/// // Convert into a type-erased array that can be passed to compute functions.
89/// let summed = sum(sliced.as_ref()).unwrap().as_primitive().typed_value::<i64>().unwrap();
90/// assert_eq!(summed, 5i64);
91/// ```
92#[derive(Clone, Debug)]
93pub struct PrimitiveArray {
94    dtype: DType,
95    buffer: ByteBuffer,
96    validity: Validity,
97    stats_set: ArrayStats,
98}
99
/// The encoding of [`PrimitiveArray`]: a fieldless unit marker type.
#[derive(Debug, Clone)]
pub struct PrimitiveEncoding;
102
103// TODO(connor): There are a lot of places where we could be using `new_unchecked` in the codebase.
104impl PrimitiveArray {
105    /// Creates a new [`PrimitiveArray`].
106    ///
107    /// # Panics
108    ///
109    /// Panics if the provided components do not satisfy the invariants documented
110    /// in [`PrimitiveArray::new_unchecked`].
111    pub fn new<T: NativePType>(buffer: impl Into<Buffer<T>>, validity: Validity) -> Self {
112        let buffer = buffer.into();
113        Self::try_new(buffer, validity).vortex_expect("PrimitiveArray construction failed")
114    }
115
116    /// Constructs a new `PrimitiveArray`.
117    ///
118    /// See [`PrimitiveArray::new_unchecked`] for more information.
119    ///
120    /// # Errors
121    ///
122    /// Returns an error if the provided components do not satisfy the invariants documented in
123    /// [`PrimitiveArray::new_unchecked`].
124    #[inline]
125    pub fn try_new<T: NativePType>(buffer: Buffer<T>, validity: Validity) -> VortexResult<Self> {
126        Self::validate(&buffer, &validity)?;
127
128        // SAFETY: validate ensures all invariants are met.
129        Ok(unsafe { Self::new_unchecked(buffer, validity) })
130    }
131
132    /// Creates a new [`PrimitiveArray`] without validation from these components:
133    ///
134    /// * `buffer` is a typed buffer containing the primitive values.
135    /// * `validity` holds the null values.
136    ///
137    /// # Safety
138    ///
139    /// The caller must ensure all of the following invariants are satisfied:
140    ///
141    /// ## Validity Requirements
142    ///
143    /// - If `validity` is [`Validity::Array`], its length must exactly equal `buffer.len()`.
144    #[inline]
145    pub unsafe fn new_unchecked<T: NativePType>(buffer: Buffer<T>, validity: Validity) -> Self {
146        Self {
147            dtype: DType::Primitive(T::PTYPE, validity.nullability()),
148            buffer: buffer.into_byte_buffer(),
149            validity,
150            stats_set: Default::default(),
151        }
152    }
153
154    /// Validates the components that would be used to create a [`PrimitiveArray`].
155    ///
156    /// This function checks all the invariants required by [`PrimitiveArray::new_unchecked`].
157    #[inline]
158    pub(crate) fn validate<T: NativePType>(
159        buffer: &Buffer<T>,
160        validity: &Validity,
161    ) -> VortexResult<()> {
162        if let Some(len) = validity.maybe_len()
163            && buffer.len() != len
164        {
165            return Err(vortex_err!(
166                "Buffer and validity length mismatch: buffer={}, validity={}",
167                buffer.len(),
168                len
169            ));
170        }
171        Ok(())
172    }
173
174    pub fn empty<T: NativePType>(nullability: Nullability) -> Self {
175        Self::new(Buffer::<T>::empty(), nullability.into())
176    }
177
178    pub fn from_byte_buffer(buffer: ByteBuffer, ptype: PType, validity: Validity) -> Self {
179        match_each_native_ptype!(ptype, |T| {
180            Self::new::<T>(Buffer::from_byte_buffer(buffer), validity)
181        })
182    }
183
184    /// Create a PrimitiveArray from an iterator of `T`.
185    /// NOTE: we cannot impl FromIterator trait since it conflicts with `FromIterator<T>`.
186    pub fn from_option_iter<T: NativePType, I: IntoIterator<Item = Option<T>>>(iter: I) -> Self {
187        let iter = iter.into_iter();
188        let mut values = BufferMut::with_capacity(iter.size_hint().0);
189        let mut validity = BooleanBufferBuilder::new(values.capacity());
190
191        for i in iter {
192            match i {
193                None => {
194                    validity.append(false);
195                    values.push(T::default());
196                }
197                Some(e) => {
198                    validity.append(true);
199                    values.push(e);
200                }
201            }
202        }
203        Self::new(values.freeze(), Validity::from(validity.finish()))
204    }
205
206    /// Create a PrimitiveArray from a byte buffer containing only the valid elements.
207    pub fn from_values_byte_buffer(
208        valid_elems_buffer: ByteBuffer,
209        ptype: PType,
210        validity: Validity,
211        n_rows: usize,
212    ) -> Self {
213        let byte_width = ptype.byte_width();
214        let alignment = Alignment::new(byte_width);
215        let buffer = match &validity {
216            Validity::AllValid | Validity::NonNullable => valid_elems_buffer.aligned(alignment),
217            Validity::AllInvalid => ByteBuffer::zeroed_aligned(n_rows * byte_width, alignment),
218            Validity::Array(is_valid) => {
219                let bool_array = is_valid.to_bool();
220                let bool_buffer = bool_array.boolean_buffer();
221                let mut bytes = ByteBufferMut::zeroed_aligned(n_rows * byte_width, alignment);
222                for (i, valid_i) in bool_buffer.set_indices().enumerate() {
223                    bytes[valid_i * byte_width..(valid_i + 1) * byte_width]
224                        .copy_from_slice(&valid_elems_buffer[i * byte_width..(i + 1) * byte_width])
225                }
226                bytes.freeze()
227            }
228        };
229
230        Self::from_byte_buffer(buffer, ptype, validity)
231    }
232
233    pub fn ptype(&self) -> PType {
234        self.dtype().as_ptype()
235    }
236
237    pub fn byte_buffer(&self) -> &ByteBuffer {
238        &self.buffer
239    }
240
241    pub fn into_byte_buffer(self) -> ByteBuffer {
242        self.buffer
243    }
244
245    pub fn buffer<T: NativePType>(&self) -> Buffer<T> {
246        if T::PTYPE != self.ptype() {
247            vortex_panic!(
248                "Attempted to get buffer of type {} from array of type {}",
249                T::PTYPE,
250                self.ptype()
251            )
252        }
253        Buffer::from_byte_buffer(self.byte_buffer().clone())
254    }
255
256    pub fn into_buffer<T: NativePType>(self) -> Buffer<T> {
257        if T::PTYPE != self.ptype() {
258            vortex_panic!(
259                "Attempted to get buffer of type {} from array of type {}",
260                T::PTYPE,
261                self.ptype()
262            )
263        }
264        Buffer::from_byte_buffer(self.buffer)
265    }
266
267    /// Extract a mutable buffer from the PrimitiveArray. Attempts to do this with zero-copy
268    /// if the buffer is uniquely owned, otherwise will make a copy.
269    pub fn into_buffer_mut<T: NativePType>(self) -> BufferMut<T> {
270        if T::PTYPE != self.ptype() {
271            vortex_panic!(
272                "Attempted to get buffer_mut of type {} from array of type {}",
273                T::PTYPE,
274                self.ptype()
275            )
276        }
277        self.into_buffer()
278            .try_into_mut()
279            .unwrap_or_else(|buffer| BufferMut::<T>::copy_from(&buffer))
280    }
281
282    /// Try to extract a mutable buffer from the PrimitiveArray with zero copy.
283    #[allow(clippy::panic_in_result_fn)]
284    pub fn try_into_buffer_mut<T: NativePType>(self) -> Result<BufferMut<T>, PrimitiveArray> {
285        if T::PTYPE != self.ptype() {
286            vortex_panic!(
287                "Attempted to get buffer_mut of type {} from array of type {}",
288                T::PTYPE,
289                self.ptype()
290            )
291        }
292        let validity = self.validity().clone();
293        Buffer::<T>::from_byte_buffer(self.into_byte_buffer())
294            .try_into_mut()
295            .map_err(|buffer| PrimitiveArray::new(buffer, validity))
296    }
297
298    /// Map each element in the array to a new value.
299    ///
300    /// This ignores validity and maps over all maybe-null elements.
301    ///
302    /// TODO(ngates): we could be smarter here if validity is sparse and only run the function
303    ///   over the valid elements.
304    pub fn map_each<T, R, F>(self, f: F) -> PrimitiveArray
305    where
306        T: NativePType,
307        R: NativePType,
308        F: FnMut(T) -> R,
309    {
310        let validity = self.validity().clone();
311        let buffer = match self.try_into_buffer_mut() {
312            Ok(buffer_mut) => buffer_mut.map_each(f),
313            Err(parray) => BufferMut::<R>::from_iter(parray.buffer::<T>().iter().copied().map(f)),
314        };
315        PrimitiveArray::new(buffer.freeze(), validity)
316    }
317
318    /// Map each element in the array to a new value.
319    ///
320    /// This doesn't ignore validity and maps over all maybe-null elements, with a bool true if
321    /// valid and false otherwise.
322    pub fn map_each_with_validity<T, R, F>(self, f: F) -> VortexResult<PrimitiveArray>
323    where
324        T: NativePType,
325        R: NativePType,
326        F: FnMut((T, bool)) -> R,
327    {
328        let validity = self.validity();
329
330        let buf_iter = self.buffer::<T>().into_iter();
331
332        let buffer = match &validity {
333            Validity::NonNullable | Validity::AllValid => {
334                BufferMut::<R>::from_iter(buf_iter.zip(iter::repeat(true)).map(f))
335            }
336            Validity::AllInvalid => {
337                BufferMut::<R>::from_iter(buf_iter.zip(iter::repeat(false)).map(f))
338            }
339            Validity::Array(val) => {
340                let val = val.to_bool();
341                BufferMut::<R>::from_iter(buf_iter.zip(val.boolean_buffer()).map(f))
342            }
343        };
344        Ok(PrimitiveArray::new(buffer.freeze(), validity.clone()))
345    }
346
347    /// Return a slice of the array's buffer.
348    ///
349    /// NOTE: these values may be nonsense if the validity buffer indicates that the value is null.
350    pub fn as_slice<T: NativePType>(&self) -> &[T] {
351        if T::PTYPE != self.ptype() {
352            vortex_panic!(
353                "Attempted to get slice of type {} from array of type {}",
354                T::PTYPE,
355                self.ptype()
356            )
357        }
358        let raw_slice = self.byte_buffer().as_ptr();
359        // SAFETY: alignment of Buffer is checked on construction
360        unsafe {
361            std::slice::from_raw_parts(raw_slice.cast(), self.byte_buffer().len() / size_of::<T>())
362        }
363    }
364
365    pub fn reinterpret_cast(&self, ptype: PType) -> Self {
366        if self.ptype() == ptype {
367            return self.clone();
368        }
369
370        assert_eq!(
371            self.ptype().byte_width(),
372            ptype.byte_width(),
373            "can't reinterpret cast between integers of two different widths"
374        );
375
376        PrimitiveArray::from_byte_buffer(self.byte_buffer().clone(), ptype, self.validity().clone())
377    }
378}
379
380impl ArrayVTable<PrimitiveVTable> for PrimitiveVTable {
381    fn len(array: &PrimitiveArray) -> usize {
382        array.byte_buffer().len() / array.ptype().byte_width()
383    }
384
385    fn dtype(array: &PrimitiveArray) -> &DType {
386        &array.dtype
387    }
388
389    fn stats(array: &PrimitiveArray) -> StatsSetRef<'_> {
390        array.stats_set.to_ref(array.as_ref())
391    }
392}
393
394impl ValidityHelper for PrimitiveArray {
395    fn validity(&self) -> &Validity {
396        &self.validity
397    }
398}
399
400impl<T: NativePType> FromIterator<T> for PrimitiveArray {
401    fn from_iter<I: IntoIterator<Item = T>>(iter: I) -> Self {
402        let values = BufferMut::from_iter(iter);
403        PrimitiveArray::new(values, Validity::NonNullable)
404    }
405}
406
407impl<T: NativePType> IntoArray for Buffer<T> {
408    fn into_array(self) -> ArrayRef {
409        PrimitiveArray::new(self, Validity::NonNullable).into_array()
410    }
411}
412
413impl<T: NativePType> IntoArray for BufferMut<T> {
414    fn into_array(self) -> ArrayRef {
415        self.freeze().into_array()
416    }
417}
418
419impl CanonicalVTable<PrimitiveVTable> for PrimitiveVTable {
420    fn canonicalize(array: &PrimitiveArray) -> Canonical {
421        Canonical::Primitive(array.clone())
422    }
423
424    fn append_to_builder(array: &PrimitiveArray, builder: &mut dyn ArrayBuilder) {
425        builder.extend_from_array(array.as_ref())
426    }
427}
428
#[cfg(test)]
mod tests {
    use vortex_buffer::buffer;
    use vortex_scalar::PValue;

    use crate::arrays::{BoolArray, PrimitiveArray};
    use crate::compute::conformance::filter::test_filter_conformance;
    use crate::compute::conformance::mask::test_mask_conformance;
    use crate::compute::conformance::search_sorted::rstest_reuse::apply;
    use crate::compute::conformance::search_sorted::{search_sorted_conformance, *};
    use crate::search_sorted::{SearchResult, SearchSorted, SearchSortedSide};
    use crate::validity::Validity;
    use crate::{ArrayRef, IntoArray};

    // Runs the shared search-sorted conformance cases against the primitive implementation.
    #[apply(search_sorted_conformance)]
    fn test_search_sorted_primitive(
        #[case] array: ArrayRef,
        #[case] value: i32,
        #[case] side: SearchSortedSide,
        #[case] expected: SearchResult,
    ) {
        let needle = Some(PValue::from(value));
        let actual = array.as_primitive_typed().search_sorted(&needle, side);
        assert_eq!(actual, expected);
    }

    #[test]
    fn test_mask_primitive_array() {
        // The same five values under every validity variant.
        let validities = [
            Validity::NonNullable,
            Validity::AllValid,
            Validity::AllInvalid,
            Validity::Array(BoolArray::from_iter([true, false, true, false, true]).into_array()),
        ];
        for validity in validities {
            let array = PrimitiveArray::new(buffer![0, 1, 2, 3, 4], validity);
            test_mask_conformance(array.as_ref());
        }
    }

    #[test]
    fn test_filter_primitive_array() {
        // Non-nullable arrays of various sizes.
        let sized_values = [
            buffer![42i32],
            buffer![0, 1],
            buffer![0, 1, 2, 3, 4],
            buffer![0, 1, 2, 3, 4, 5, 6, 7],
        ];
        for values in sized_values {
            let array = PrimitiveArray::new(values, Validity::NonNullable);
            test_filter_conformance(array.as_ref());
        }

        // Nullable arrays.
        let all_valid = PrimitiveArray::new(buffer![0, 1, 2, 3, 4], Validity::AllValid);
        test_filter_conformance(all_valid.as_ref());

        let mask = BoolArray::from_iter([true, false, true, false, true, true]).into_array();
        let with_nulls = PrimitiveArray::new(buffer![0, 1, 2, 3, 4, 5], Validity::Array(mask));
        test_filter_conformance(with_nulls.as_ref());
    }
}