vortex_array/array/
implementation.rs

1use std::any::Any;
2use std::fmt::Debug;
3use std::sync::Arc;
4
5use vortex_dtype::DType;
6use vortex_error::{VortexResult, vortex_bail};
7use vortex_mask::Mask;
8
9use crate::array::canonical::ArrayCanonicalImpl;
10use crate::array::validity::ArrayValidityImpl;
11use crate::array::visitor::ArrayVisitorImpl;
12use crate::builders::ArrayBuilder;
13use crate::compute::{ComputeFn, InvocationArgs, Output};
14use crate::stats::{Precision, Stat, StatsSetRef};
15use crate::vtable::VTableRef;
16use crate::{
17    Array, ArrayRef, ArrayStatisticsImpl, ArrayVariantsImpl, ArrayVisitor, Canonical, Encoding,
18    EncodingId,
19};
20
21/// A trait used to encapsulate common implementation behaviour for a Vortex [`Array`].
22pub trait ArrayImpl:
23    'static
24    + Send
25    + Sync
26    + Debug
27    + Clone
28    + ArrayCanonicalImpl
29    + ArrayStatisticsImpl
30    + ArrayValidityImpl
31    + ArrayVariantsImpl
32    + ArrayVisitorImpl<<Self::Encoding as Encoding>::Metadata>
33{
34    type Encoding: Encoding;
35
36    fn _len(&self) -> usize;
37    fn _dtype(&self) -> &DType;
38    fn _vtable(&self) -> VTableRef;
39
40    /// Replace the children of this array with the given arrays.
41    ///
42    /// ## Pre-conditions
43    ///
44    /// - The number of given children matches the current number of children of the array.
45    fn _with_children(&self, children: &[ArrayRef]) -> VortexResult<Self>;
46
47    /// Dynamically invoke a kernel for the given compute function.
48    fn _invoke(
49        &self,
50        _compute_fn: &ComputeFn,
51        _args: &InvocationArgs,
52    ) -> VortexResult<Option<Output>> {
53        Ok(None)
54    }
55}
56
57impl<A: ArrayImpl + 'static> Array for A {
58    fn as_any(&self) -> &dyn Any {
59        self
60    }
61
62    fn as_any_arc(self: Arc<Self>) -> Arc<dyn Any + Send + Sync> {
63        self
64    }
65
66    fn to_array(&self) -> ArrayRef {
67        Arc::new(self.clone())
68    }
69
70    fn into_array(self) -> ArrayRef
71    where
72        Self: Sized,
73    {
74        Arc::new(self)
75    }
76
77    fn len(&self) -> usize {
78        ArrayImpl::_len(self)
79    }
80
81    fn dtype(&self) -> &DType {
82        ArrayImpl::_dtype(self)
83    }
84
85    fn encoding(&self) -> EncodingId {
86        self.vtable().id()
87    }
88
89    fn vtable(&self) -> VTableRef {
90        ArrayImpl::_vtable(self)
91    }
92
93    /// Returns whether the item at `index` is valid.
94    fn is_valid(&self, index: usize) -> VortexResult<bool> {
95        if index >= self.len() {
96            vortex_bail!("Index out of bounds: {} >= {}", index, self.len());
97        }
98        ArrayValidityImpl::_is_valid(self, index)
99    }
100
101    /// Returns whether the item at `index` is invalid.
102    fn is_invalid(&self, index: usize) -> VortexResult<bool> {
103        self.is_valid(index).map(|valid| !valid)
104    }
105
106    /// Returns whether all items in the array are valid.
107    ///
108    /// This is usually cheaper than computing a precise `valid_count`.
109    fn all_valid(&self) -> VortexResult<bool> {
110        ArrayValidityImpl::_all_valid(self)
111    }
112
113    /// Returns whether the array is all invalid.
114    ///
115    /// This is usually cheaper than computing a precise `invalid_count`.
116    fn all_invalid(&self) -> VortexResult<bool> {
117        ArrayValidityImpl::_all_invalid(self)
118    }
119
120    /// Returns the number of valid elements in the array.
121    fn valid_count(&self) -> VortexResult<usize> {
122        if let Some(Precision::Exact(invalid_count)) =
123            self.statistics().get_as::<usize>(Stat::NullCount)
124        {
125            return Ok(self.len() - invalid_count);
126        }
127
128        let count = ArrayValidityImpl::_valid_count(self)?;
129        assert!(count <= self.len(), "Valid count exceeds array length");
130
131        self.statistics()
132            .set(Stat::NullCount, Precision::exact(self.len() - count));
133
134        Ok(count)
135    }
136
137    /// Returns the number of invalid elements in the array.
138    fn invalid_count(&self) -> VortexResult<usize> {
139        if let Some(Precision::Exact(invalid_count)) =
140            self.statistics().get_as::<usize>(Stat::NullCount)
141        {
142            return Ok(invalid_count);
143        }
144
145        let count = ArrayValidityImpl::_invalid_count(self)?;
146        assert!(count <= self.len(), "Invalid count exceeds array length");
147
148        self.statistics()
149            .set(Stat::NullCount, Precision::exact(count));
150
151        Ok(count)
152    }
153
154    /// Returns the canonical validity mask for the array.
155    fn validity_mask(&self) -> VortexResult<Mask> {
156        let mask = ArrayValidityImpl::_validity_mask(self)?;
157        assert_eq!(mask.len(), self.len(), "Validity mask length mismatch");
158        Ok(mask)
159    }
160
161    /// Returns the canonical representation of the array.
162    fn to_canonical(&self) -> VortexResult<Canonical> {
163        let canonical = ArrayCanonicalImpl::_to_canonical(self)?;
164        assert_eq!(
165            canonical.as_ref().len(),
166            self.len(),
167            "Canonical length mismatch"
168        );
169        assert_eq!(
170            canonical.as_ref().dtype(),
171            self.dtype(),
172            "Canonical dtype mismatch"
173        );
174        canonical.as_ref().statistics().inherit(self.statistics());
175        Ok(canonical)
176    }
177
178    /// Writes the array into the canonical builder.
179    ///
180    /// The [`DType`] of the builder must match that of the array.
181    fn append_to_builder(&self, builder: &mut dyn ArrayBuilder) -> VortexResult<()> {
182        if builder.dtype() != self.dtype() {
183            vortex_bail!(
184                "Builder dtype mismatch: expected {}, got {}",
185                self.dtype(),
186                builder.dtype(),
187            );
188        }
189        let len = builder.len();
190
191        ArrayCanonicalImpl::_append_to_builder(self, builder)?;
192        assert_eq!(
193            len + self.len(),
194            builder.len(),
195            "Builder length mismatch after writing array for encoding {}",
196            self.encoding(),
197        );
198        Ok(())
199    }
200
201    fn statistics(&self) -> StatsSetRef<'_> {
202        self._stats_ref()
203    }
204
205    fn with_children(&self, children: &[ArrayRef]) -> VortexResult<ArrayRef> {
206        if self.nchildren() != children.len() {
207            vortex_bail!("Child count mismatch");
208        }
209
210        for (s, o) in self.children().iter().zip(children.iter()) {
211            assert_eq!(s.len(), o.len());
212        }
213
214        Ok(self._with_children(children)?.into_array())
215    }
216
217    fn invoke(
218        &self,
219        compute_fn: &ComputeFn,
220        args: &InvocationArgs,
221    ) -> VortexResult<Option<Output>> {
222        self._invoke(compute_fn, args)
223    }
224}