vortex_array/array/
mod.rs

1mod convert;
2mod statistics;
3mod visitor;
4
5use std::any::Any;
6use std::fmt::{Debug, Display, Formatter};
7use std::sync::Arc;
8
9pub use convert::*;
10pub use visitor::*;
11use vortex_buffer::ByteBuffer;
12use vortex_dtype::DType;
13use vortex_error::{VortexExpect, VortexResult, vortex_bail, vortex_err};
14use vortex_mask::Mask;
15use vortex_scalar::Scalar;
16
17use crate::arrays::{
18    BoolEncoding, DecimalEncoding, ExtensionEncoding, ListEncoding, NullEncoding,
19    PrimitiveEncoding, StructEncoding, VarBinEncoding, VarBinViewEncoding,
20};
21use crate::builders::ArrayBuilder;
22use crate::compute::{ComputeFn, Cost, InvocationArgs, Output};
23use crate::serde::ArrayChildren;
24use crate::stats::{Precision, Stat, StatsProviderExt, StatsSetRef};
25use crate::vtable::{
26    ArrayVTable, CanonicalVTable, ComputeVTable, OperationsVTable, SerdeVTable, VTable,
27    ValidityVTable, VisitorVTable,
28};
29use crate::{Canonical, EncodingId, EncodingRef, SerializeMetadata};
30
31/// The public API trait for all Vortex arrays.
32pub trait Array: 'static + private::Sealed + Send + Sync + Debug + ArrayVisitor {
33    /// Returns the array as a reference to a generic [`Any`] trait object.
34    fn as_any(&self) -> &dyn Any;
35
36    /// Returns the array as an [`ArrayRef`].
37    fn to_array(&self) -> ArrayRef;
38
39    /// Returns the length of the array.
40    fn len(&self) -> usize;
41
42    /// Returns whether the array is empty (has zero rows).
43    fn is_empty(&self) -> bool {
44        self.len() == 0
45    }
46
47    /// Returns the logical Vortex [`DType`] of the array.
48    fn dtype(&self) -> &DType;
49
50    /// Returns the encoding of the array.
51    fn encoding(&self) -> EncodingRef;
52
53    /// Returns the encoding ID of the array.
54    fn encoding_id(&self) -> EncodingId;
55
56    /// Performs a constant-time slice of the array.
57    fn slice(&self, start: usize, end: usize) -> VortexResult<ArrayRef>;
58
59    /// Fetch the scalar at the given index.
60    fn scalar_at(&self, index: usize) -> VortexResult<Scalar>;
61
62    /// Returns whether the array is of the given encoding.
63    fn is_encoding(&self, encoding: EncodingId) -> bool {
64        self.encoding_id() == encoding
65    }
66
67    /// Returns whether this array is an arrow encoding.
68    // TODO(ngates): this shouldn't live here.
69    fn is_arrow(&self) -> bool {
70        self.is_encoding(NullEncoding.id())
71            || self.is_encoding(BoolEncoding.id())
72            || self.is_encoding(PrimitiveEncoding.id())
73            || self.is_encoding(VarBinEncoding.id())
74            || self.is_encoding(VarBinViewEncoding.id())
75    }
76
77    /// Whether the array is of a canonical encoding.
78    // TODO(ngates): this shouldn't live here.
79    fn is_canonical(&self) -> bool {
80        self.is_encoding(NullEncoding.id())
81            || self.is_encoding(BoolEncoding.id())
82            || self.is_encoding(PrimitiveEncoding.id())
83            || self.is_encoding(DecimalEncoding.id())
84            || self.is_encoding(StructEncoding.id())
85            || self.is_encoding(ListEncoding.id())
86            || self.is_encoding(VarBinViewEncoding.id())
87            || self.is_encoding(ExtensionEncoding.id())
88    }
89
90    /// Returns whether the item at `index` is valid.
91    fn is_valid(&self, index: usize) -> VortexResult<bool>;
92
93    /// Returns whether the item at `index` is invalid.
94    fn is_invalid(&self, index: usize) -> VortexResult<bool>;
95
96    /// Returns whether all items in the array are valid.
97    ///
98    /// This is usually cheaper than computing a precise `valid_count`.
99    fn all_valid(&self) -> VortexResult<bool>;
100
101    /// Returns whether the array is all invalid.
102    ///
103    /// This is usually cheaper than computing a precise `invalid_count`.
104    fn all_invalid(&self) -> VortexResult<bool>;
105
106    /// Returns the number of valid elements in the array.
107    fn valid_count(&self) -> VortexResult<usize>;
108
109    /// Returns the number of invalid elements in the array.
110    fn invalid_count(&self) -> VortexResult<usize>;
111
112    /// Returns the canonical validity mask for the array.
113    fn validity_mask(&self) -> VortexResult<Mask>;
114
115    /// Returns the canonical representation of the array.
116    fn to_canonical(&self) -> VortexResult<Canonical>;
117
118    /// Writes the array into the canonical builder.
119    ///
120    /// The [`DType`] of the builder must match that of the array.
121    fn append_to_builder(&self, builder: &mut dyn ArrayBuilder) -> VortexResult<()>;
122
123    /// Returns the statistics of the array.
124    // TODO(ngates): change how this works. It's weird.
125    fn statistics(&self) -> StatsSetRef<'_>;
126
127    /// Replaces the children of the array with the given array references.
128    fn with_children(&self, children: &[ArrayRef]) -> VortexResult<ArrayRef>;
129
130    /// Optionally invoke a kernel for the given compute function.
131    ///
132    /// These encoding-specific kernels are independent of kernels registered directly with
133    /// compute functions using [`ComputeFn::register_kernel`], and are attempted only if none of
134    /// the function-specific kernels returns a result.
135    ///
136    /// This allows encodings the opportunity to generically implement many compute functions
137    /// that share some property, for example [`ComputeFn::is_elementwise`], without prior
138    /// knowledge of the function itself, while still allowing users to override the implementation
139    /// of compute functions for built-in encodings. For an example, see the implementation for
140    /// chunked arrays.
141    ///
142    /// The first input in the [`InvocationArgs`] is always the array itself.
143    ///
144    /// Warning: do not call `compute_fn.invoke(args)` directly, as this will result in a recursive
145    /// call.
146    fn invoke(&self, compute_fn: &ComputeFn, args: &InvocationArgs)
147    -> VortexResult<Option<Output>>;
148}
149
150impl Array for Arc<dyn Array> {
151    fn as_any(&self) -> &dyn Any {
152        self.as_ref().as_any()
153    }
154
155    fn to_array(&self) -> ArrayRef {
156        self.clone()
157    }
158
159    fn len(&self) -> usize {
160        self.as_ref().len()
161    }
162
163    fn dtype(&self) -> &DType {
164        self.as_ref().dtype()
165    }
166
167    fn encoding(&self) -> EncodingRef {
168        self.as_ref().encoding()
169    }
170
171    fn encoding_id(&self) -> EncodingId {
172        self.as_ref().encoding_id()
173    }
174
175    fn slice(&self, start: usize, end: usize) -> VortexResult<ArrayRef> {
176        self.as_ref().slice(start, end)
177    }
178
179    fn scalar_at(&self, index: usize) -> VortexResult<Scalar> {
180        self.as_ref().scalar_at(index)
181    }
182
183    fn is_valid(&self, index: usize) -> VortexResult<bool> {
184        self.as_ref().is_valid(index)
185    }
186
187    fn is_invalid(&self, index: usize) -> VortexResult<bool> {
188        self.as_ref().is_invalid(index)
189    }
190
191    fn all_valid(&self) -> VortexResult<bool> {
192        self.as_ref().all_valid()
193    }
194
195    fn all_invalid(&self) -> VortexResult<bool> {
196        self.as_ref().all_invalid()
197    }
198
199    fn valid_count(&self) -> VortexResult<usize> {
200        self.as_ref().valid_count()
201    }
202
203    fn invalid_count(&self) -> VortexResult<usize> {
204        self.as_ref().invalid_count()
205    }
206
207    fn validity_mask(&self) -> VortexResult<Mask> {
208        self.as_ref().validity_mask()
209    }
210
211    fn to_canonical(&self) -> VortexResult<Canonical> {
212        self.as_ref().to_canonical()
213    }
214
215    fn append_to_builder(&self, builder: &mut dyn ArrayBuilder) -> VortexResult<()> {
216        self.as_ref().append_to_builder(builder)
217    }
218
219    fn statistics(&self) -> StatsSetRef<'_> {
220        self.as_ref().statistics()
221    }
222
223    fn with_children(&self, children: &[ArrayRef]) -> VortexResult<ArrayRef> {
224        self.as_ref().with_children(children)
225    }
226
227    fn invoke(
228        &self,
229        compute_fn: &ComputeFn,
230        args: &InvocationArgs,
231    ) -> VortexResult<Option<Output>> {
232        self.as_ref().invoke(compute_fn, args)
233    }
234}
235
236/// A reference counted pointer to a dynamic [`Array`] trait object.
237pub type ArrayRef = Arc<dyn Array>;
238
239impl ToOwned for dyn Array {
240    type Owned = ArrayRef;
241
242    fn to_owned(&self) -> Self::Owned {
243        self.to_array()
244    }
245}
246
247impl dyn Array + '_ {
248    /// Returns the array downcast to the given `A`.
249    pub fn as_<V: VTable>(&self) -> &V::Array {
250        self.as_opt::<V>().vortex_expect("Failed to downcast")
251    }
252
253    /// Returns the array downcast to the given `A`.
254    pub fn as_opt<V: VTable>(&self) -> Option<&V::Array> {
255        self.as_any()
256            .downcast_ref::<ArrayAdapter<V>>()
257            .map(|array_adapter| &array_adapter.0)
258    }
259
260    /// Is self an array with encoding from vtable `V`.
261    pub fn is<V: VTable>(&self) -> bool {
262        self.as_opt::<V>().is_some()
263    }
264}
265
266impl Display for dyn Array {
267    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
268        write!(
269            f,
270            "{}({}, len={})",
271            self.encoding_id(),
272            self.dtype(),
273            self.len()
274        )
275    }
276}
277
278impl dyn Array + '_ {
279    /// Total size of the array in bytes, including all children and buffers.
280    // TODO(ngates): this should return u64
281    pub fn nbytes(&self) -> usize {
282        let mut nbytes = 0;
283        for array in self.depth_first_traversal() {
284            for buffer in array.buffers() {
285                nbytes += buffer.len();
286            }
287        }
288        nbytes
289    }
290}
291
292mod private {
293    use super::*;
294
295    pub trait Sealed {}
296
297    impl<V: VTable> Sealed for ArrayAdapter<V> {}
298    impl Sealed for Arc<dyn Array> {}
299}
300
301/// Adapter struct used to lift the [`VTable`] trait into an object-safe [`Array`]
302/// implementation.
303///
304/// Since this is a unit struct with `repr(transparent)`, we are able to turn un-adapted array
305/// structs into [`dyn Array`] using some cheeky casting inside [`std::ops::Deref`] and
306/// [`AsRef`]. See the `vtable!` macro for more details.
307#[repr(transparent)]
308pub struct ArrayAdapter<V: VTable>(V::Array);
309
310impl<V: VTable> Debug for ArrayAdapter<V> {
311    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
312        self.0.fmt(f)
313    }
314}
315
316impl<V: VTable> Array for ArrayAdapter<V> {
317    fn as_any(&self) -> &dyn Any {
318        self
319    }
320
321    fn to_array(&self) -> ArrayRef {
322        Arc::new(ArrayAdapter::<V>(self.0.clone()))
323    }
324
325    fn len(&self) -> usize {
326        <V::ArrayVTable as ArrayVTable<V>>::len(&self.0)
327    }
328
329    fn dtype(&self) -> &DType {
330        <V::ArrayVTable as ArrayVTable<V>>::dtype(&self.0)
331    }
332
333    fn encoding(&self) -> EncodingRef {
334        V::encoding(&self.0)
335    }
336
337    fn encoding_id(&self) -> EncodingId {
338        V::encoding(&self.0).id()
339    }
340
341    fn slice(&self, start: usize, stop: usize) -> VortexResult<ArrayRef> {
342        if start == 0 && stop == self.len() {
343            return Ok(self.to_array());
344        }
345
346        if start > self.len() {
347            vortex_bail!(OutOfBounds: start, 0, self.len());
348        }
349        if stop > self.len() {
350            vortex_bail!(OutOfBounds: stop, 0, self.len());
351        }
352        if start > stop {
353            vortex_bail!("start ({start}) must be <= stop ({stop})");
354        }
355
356        if start == stop {
357            return Ok(Canonical::empty(self.dtype()).into_array());
358        }
359
360        // We know that constant array don't need stats propagation, so we can avoid the overhead of
361        // computing derived stats and merging them in.
362        // TODO(ngates): skip the is_constant check here, it can force an expensive compute.
363        // TODO(ngates): provide a means to slice an array _without_ propagating stats.
364        let derived_stats = (!self.0.is_constant_opts(Cost::Negligible)).then(|| {
365            let stats = self.statistics().to_owned();
366
367            // an array that is not constant can become constant after slicing
368            let is_constant = stats.get_as::<bool>(Stat::IsConstant);
369            let is_sorted = stats.get_as::<bool>(Stat::IsSorted);
370            let is_strict_sorted = stats.get_as::<bool>(Stat::IsStrictSorted);
371
372            let mut stats = stats.keep_inexact_stats(&[
373                Stat::Max,
374                Stat::Min,
375                Stat::NullCount,
376                Stat::UncompressedSizeInBytes,
377            ]);
378
379            if is_constant == Some(Precision::Exact(true)) {
380                stats.set(Stat::IsConstant, Precision::exact(true));
381            }
382            if is_sorted == Some(Precision::Exact(true)) {
383                stats.set(Stat::IsSorted, Precision::exact(true));
384            }
385            if is_strict_sorted == Some(Precision::Exact(true)) {
386                stats.set(Stat::IsStrictSorted, Precision::exact(true));
387            }
388
389            stats
390        });
391
392        let sliced = <V::OperationsVTable as OperationsVTable<V>>::slice(&self.0, start, stop)?;
393
394        assert_eq!(
395            sliced.len(),
396            stop - start,
397            "Slice length mismatch {}",
398            self.encoding_id()
399        );
400        assert_eq!(
401            sliced.dtype(),
402            self.dtype(),
403            "Slice dtype mismatch {}",
404            self.encoding_id()
405        );
406
407        if let Some(derived_stats) = derived_stats {
408            let mut stats = sliced.statistics().to_owned();
409            stats.combine_sets(&derived_stats, self.dtype())?;
410            for (stat, val) in stats.into_iter() {
411                sliced.statistics().set(stat, val)
412            }
413        }
414
415        Ok(sliced)
416    }
417
418    fn scalar_at(&self, index: usize) -> VortexResult<Scalar> {
419        if index >= self.len() {
420            vortex_bail!(OutOfBounds: index, 0, self.len());
421        }
422        if self.is_invalid(index)? {
423            return Ok(Scalar::null(self.dtype().clone()));
424        }
425        let scalar = <V::OperationsVTable as OperationsVTable<V>>::scalar_at(&self.0, index)?;
426        assert_eq!(self.dtype(), scalar.dtype(), "Scalar dtype mismatch");
427        Ok(scalar)
428    }
429
430    fn is_valid(&self, index: usize) -> VortexResult<bool> {
431        if index >= self.len() {
432            vortex_bail!(OutOfBounds: index, 0, self.len());
433        }
434        <V::ValidityVTable as ValidityVTable<V>>::is_valid(&self.0, index)
435    }
436
437    fn is_invalid(&self, index: usize) -> VortexResult<bool> {
438        self.is_valid(index).map(|valid| !valid)
439    }
440
441    fn all_valid(&self) -> VortexResult<bool> {
442        <V::ValidityVTable as ValidityVTable<V>>::all_valid(&self.0)
443    }
444
445    fn all_invalid(&self) -> VortexResult<bool> {
446        <V::ValidityVTable as ValidityVTable<V>>::all_invalid(&self.0)
447    }
448
449    fn valid_count(&self) -> VortexResult<usize> {
450        if let Some(Precision::Exact(invalid_count)) =
451            self.statistics().get_as::<usize>(Stat::NullCount)
452        {
453            return Ok(self.len() - invalid_count);
454        }
455
456        let count = <V::ValidityVTable as ValidityVTable<V>>::valid_count(&self.0)?;
457        assert!(count <= self.len(), "Valid count exceeds array length");
458
459        self.statistics()
460            .set(Stat::NullCount, Precision::exact(self.len() - count));
461
462        Ok(count)
463    }
464
465    fn invalid_count(&self) -> VortexResult<usize> {
466        if let Some(Precision::Exact(invalid_count)) =
467            self.statistics().get_as::<usize>(Stat::NullCount)
468        {
469            return Ok(invalid_count);
470        }
471
472        let count = <V::ValidityVTable as ValidityVTable<V>>::invalid_count(&self.0)?;
473        assert!(count <= self.len(), "Invalid count exceeds array length");
474
475        self.statistics()
476            .set(Stat::NullCount, Precision::exact(count));
477
478        Ok(count)
479    }
480
481    fn validity_mask(&self) -> VortexResult<Mask> {
482        let mask = <V::ValidityVTable as ValidityVTable<V>>::validity_mask(&self.0)?;
483        assert_eq!(mask.len(), self.len(), "Validity mask length mismatch");
484        Ok(mask)
485    }
486
487    fn to_canonical(&self) -> VortexResult<Canonical> {
488        let canonical = <V::CanonicalVTable as CanonicalVTable<V>>::canonicalize(&self.0)?;
489        assert_eq!(
490            self.len(),
491            canonical.as_ref().len(),
492            "Canonical length mismatch {}. Expected {} but encoded into {}.",
493            self.encoding_id(),
494            self.len(),
495            canonical.as_ref().len()
496        );
497        assert_eq!(
498            self.dtype(),
499            canonical.as_ref().dtype(),
500            "Canonical dtype mismatch {}. Expected {} but encoded into {}.",
501            self.encoding_id(),
502            self.dtype(),
503            canonical.as_ref().dtype()
504        );
505        canonical.as_ref().statistics().inherit(self.statistics());
506        Ok(canonical)
507    }
508
509    fn append_to_builder(&self, builder: &mut dyn ArrayBuilder) -> VortexResult<()> {
510        if builder.dtype() != self.dtype() {
511            vortex_bail!(
512                "Builder dtype mismatch: expected {}, got {}",
513                self.dtype(),
514                builder.dtype(),
515            );
516        }
517        let len = builder.len();
518
519        <V::CanonicalVTable as CanonicalVTable<V>>::append_to_builder(&self.0, builder)?;
520        assert_eq!(
521            len + self.len(),
522            builder.len(),
523            "Builder length mismatch after writing array for encoding {}",
524            self.encoding_id(),
525        );
526        Ok(())
527    }
528
529    fn statistics(&self) -> StatsSetRef<'_> {
530        <V::ArrayVTable as ArrayVTable<V>>::stats(&self.0)
531    }
532
533    fn with_children(&self, children: &[ArrayRef]) -> VortexResult<ArrayRef> {
534        struct ReplacementChildren<'a> {
535            children: &'a [ArrayRef],
536        }
537
538        impl ArrayChildren for ReplacementChildren<'_> {
539            fn get(&self, index: usize, dtype: &DType, len: usize) -> VortexResult<ArrayRef> {
540                if index >= self.children.len() {
541                    vortex_bail!(OutOfBounds: index, 0, self.children.len());
542                }
543                let child = &self.children[index];
544                if child.len() != len {
545                    vortex_bail!(
546                        "Child length mismatch: expected {}, got {}",
547                        len,
548                        child.len()
549                    );
550                }
551                if child.dtype() != dtype {
552                    vortex_bail!(
553                        "Child dtype mismatch: expected {}, got {}",
554                        dtype,
555                        child.dtype()
556                    );
557                }
558                Ok(child.clone())
559            }
560
561            fn len(&self) -> usize {
562                self.children.len()
563            }
564        }
565
566        let metadata = self.metadata()?.ok_or_else(|| {
567            vortex_err!("Cannot replace children for arrays that do not support serialization")
568        })?;
569
570        // Replace the children of the array by re-building the array from parts.
571        self.encoding().build(
572            self.dtype(),
573            self.len(),
574            &metadata,
575            &self.buffers(),
576            &ReplacementChildren { children },
577        )
578    }
579
580    fn invoke(
581        &self,
582        compute_fn: &ComputeFn,
583        args: &InvocationArgs,
584    ) -> VortexResult<Option<Output>> {
585        <V::ComputeVTable as ComputeVTable<V>>::invoke(&self.0, compute_fn, args)
586    }
587}
588
589impl<V: VTable> ArrayVisitor for ArrayAdapter<V> {
590    fn children(&self) -> Vec<ArrayRef> {
591        struct ChildrenCollector {
592            children: Vec<ArrayRef>,
593        }
594
595        impl ArrayChildVisitor for ChildrenCollector {
596            fn visit_child(&mut self, _name: &str, array: &dyn Array) {
597                self.children.push(array.to_array());
598            }
599        }
600
601        let mut collector = ChildrenCollector {
602            children: Vec::new(),
603        };
604        <V::VisitorVTable as VisitorVTable<V>>::visit_children(&self.0, &mut collector);
605        collector.children
606    }
607
608    fn nchildren(&self) -> usize {
609        <V::VisitorVTable as VisitorVTable<V>>::nchildren(&self.0)
610    }
611
612    fn children_names(&self) -> Vec<String> {
613        struct ChildNameCollector {
614            names: Vec<String>,
615        }
616
617        impl ArrayChildVisitor for ChildNameCollector {
618            fn visit_child(&mut self, name: &str, _array: &dyn Array) {
619                self.names.push(name.to_string());
620            }
621        }
622
623        let mut collector = ChildNameCollector { names: Vec::new() };
624        <V::VisitorVTable as VisitorVTable<V>>::visit_children(&self.0, &mut collector);
625        collector.names
626    }
627
628    fn named_children(&self) -> Vec<(String, ArrayRef)> {
629        struct NamedChildrenCollector {
630            children: Vec<(String, ArrayRef)>,
631        }
632
633        impl ArrayChildVisitor for NamedChildrenCollector {
634            fn visit_child(&mut self, name: &str, array: &dyn Array) {
635                self.children.push((name.to_string(), array.to_array()));
636            }
637        }
638
639        let mut collector = NamedChildrenCollector {
640            children: Vec::new(),
641        };
642
643        <V::VisitorVTable as VisitorVTable<V>>::visit_children(&self.0, &mut collector);
644        collector.children
645    }
646
647    fn buffers(&self) -> Vec<ByteBuffer> {
648        struct BufferCollector {
649            buffers: Vec<ByteBuffer>,
650        }
651
652        impl ArrayBufferVisitor for BufferCollector {
653            fn visit_buffer(&mut self, buffer: &ByteBuffer) {
654                self.buffers.push(buffer.clone());
655            }
656        }
657
658        let mut collector = BufferCollector {
659            buffers: Vec::new(),
660        };
661        <V::VisitorVTable as VisitorVTable<V>>::visit_buffers(&self.0, &mut collector);
662        collector.buffers
663    }
664
665    fn nbuffers(&self) -> usize {
666        <V::VisitorVTable as VisitorVTable<V>>::nbuffers(&self.0)
667    }
668
669    fn metadata(&self) -> VortexResult<Option<Vec<u8>>> {
670        Ok(<V::SerdeVTable as SerdeVTable<V>>::metadata(&self.0)?.map(|m| m.serialize()))
671    }
672
673    fn metadata_fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
674        match <V::SerdeVTable as SerdeVTable<V>>::metadata(&self.0) {
675            Err(e) => write!(f, "<serde error: {e}>"),
676            Ok(None) => write!(f, "<serde not supported>"),
677            Ok(Some(metadata)) => Debug::fmt(&metadata, f),
678        }
679    }
680}