Skip to main content

vortex_array/vtable/
dyn_.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::any::type_name;
5use std::fmt;
6use std::fmt::Debug;
7use std::fmt::Formatter;
8use std::marker::PhantomData;
9
10use arcref::ArcRef;
11use vortex_dtype::DType;
12use vortex_error::VortexExpect;
13use vortex_error::VortexResult;
14use vortex_error::vortex_ensure;
15use vortex_session::VortexSession;
16
17use crate::Array;
18use crate::ArrayAdapter;
19use crate::ArrayRef;
20use crate::buffer::BufferHandle;
21use crate::executor::ExecutionCtx;
22use crate::serde::ArrayChildren;
23use crate::vtable::VTable;
24
25/// ArrayId is a globally unique name for the array's vtable.
26pub type ArrayId = ArcRef<str>;
27
28/// Dynamically typed vtable trait.
29///
30/// This trait is sealed, therefore users should implement the strongly typed [`VTable`] trait
31/// instead. The [`ArrayVTableExt::vtable`] function can be used to lift the implementation into
32/// this object-safe form.
33///
34/// This trait contains the implementation API for Vortex arrays, allowing us to keep the public
35/// [`Array`] trait API to a minimum.
36pub trait DynVTable: 'static + private::Sealed + Send + Sync + Debug {
37    #[allow(clippy::too_many_arguments)]
38    fn build(
39        &self,
40        id: ArrayId,
41        dtype: &DType,
42        len: usize,
43        metadata: &[u8],
44        buffers: &[BufferHandle],
45        children: &dyn ArrayChildren,
46        session: &VortexSession,
47    ) -> VortexResult<ArrayRef>;
48    fn with_children(&self, array: &dyn Array, children: Vec<ArrayRef>) -> VortexResult<ArrayRef>;
49
50    /// See [`VTable::reduce`]
51    fn reduce(&self, array: &ArrayRef) -> VortexResult<Option<ArrayRef>>;
52
53    /// See [`VTable::reduce_parent`]
54    fn reduce_parent(
55        &self,
56        array: &ArrayRef,
57        parent: &ArrayRef,
58        child_idx: usize,
59    ) -> VortexResult<Option<ArrayRef>>;
60
61    /// See [`VTable::execute`]
62    fn execute(&self, array: &ArrayRef, ctx: &mut ExecutionCtx) -> VortexResult<ArrayRef>;
63
64    /// See [`VTable::execute_parent`]
65    fn execute_parent(
66        &self,
67        array: &ArrayRef,
68        parent: &ArrayRef,
69        child_idx: usize,
70        ctx: &mut ExecutionCtx,
71    ) -> VortexResult<Option<ArrayRef>>;
72}
73
74/// Adapter struct used to lift the [`VTable`] trait into an object-safe [`DynVTable`]
75/// implementation.
76struct ArrayVTableAdapter<V: VTable>(PhantomData<V>);
77
78impl<V: VTable> DynVTable for ArrayVTableAdapter<V> {
79    fn build(
80        &self,
81        _id: ArrayId,
82        dtype: &DType,
83        len: usize,
84        metadata: &[u8],
85        buffers: &[BufferHandle],
86        children: &dyn ArrayChildren,
87        session: &VortexSession,
88    ) -> VortexResult<ArrayRef> {
89        let metadata = V::deserialize(metadata, dtype, len, buffers, session)?;
90        let array = V::build(dtype, len, &metadata, buffers, children)?;
91        assert_eq!(array.len(), len, "Array length mismatch after building");
92        assert_eq!(array.dtype(), dtype, "Array dtype mismatch after building");
93        Ok(array.to_array())
94    }
95
96    fn with_children(&self, array: &dyn Array, children: Vec<ArrayRef>) -> VortexResult<ArrayRef> {
97        let mut array = array.as_::<V>().clone();
98        V::with_children(&mut array, children)?;
99        Ok(array.to_array())
100    }
101
102    fn reduce(&self, array: &ArrayRef) -> VortexResult<Option<ArrayRef>> {
103        let Some(reduced) = V::reduce(downcast::<V>(array))? else {
104            return Ok(None);
105        };
106        vortex_ensure!(
107            reduced.len() == array.len(),
108            "Reduced array length mismatch from {} to {}",
109            array.encoding_id(),
110            reduced.encoding_id()
111        );
112        vortex_ensure!(
113            reduced.dtype() == array.dtype(),
114            "Reduced array dtype mismatch from {} to {}",
115            array.encoding_id(),
116            reduced.encoding_id()
117        );
118        Ok(Some(reduced))
119    }
120
121    fn reduce_parent(
122        &self,
123        array: &ArrayRef,
124        parent: &ArrayRef,
125        child_idx: usize,
126    ) -> VortexResult<Option<ArrayRef>> {
127        let Some(reduced) = V::reduce_parent(downcast::<V>(array), parent, child_idx)? else {
128            return Ok(None);
129        };
130
131        vortex_ensure!(
132            reduced.len() == parent.len(),
133            "Reduced array length mismatch from {} to {}",
134            parent.encoding_id(),
135            reduced.encoding_id()
136        );
137        vortex_ensure!(
138            reduced.dtype() == parent.dtype(),
139            "Reduced array dtype mismatch from {} to {}",
140            parent.encoding_id(),
141            reduced.encoding_id()
142        );
143
144        Ok(Some(reduced))
145    }
146
147    fn execute(&self, array: &ArrayRef, ctx: &mut ExecutionCtx) -> VortexResult<ArrayRef> {
148        let result = V::execute(downcast::<V>(array), ctx)?;
149
150        if cfg!(debug_assertions) {
151            vortex_ensure!(
152                result.as_ref().len() == array.len(),
153                "Result length mismatch for {:?}",
154                self
155            );
156            vortex_ensure!(
157                result.as_ref().dtype() == array.dtype(),
158                "Executed canonical dtype mismatch for {:?}",
159                self
160            );
161        }
162
163        // TODO(ngates): do we want to do this on every execution? We used to in to_canonical.
164        result
165            .as_ref()
166            .statistics()
167            .inherit_from(array.statistics());
168
169        Ok(result)
170    }
171
172    fn execute_parent(
173        &self,
174        array: &ArrayRef,
175        parent: &ArrayRef,
176        child_idx: usize,
177        ctx: &mut ExecutionCtx,
178    ) -> VortexResult<Option<ArrayRef>> {
179        let Some(result) = V::execute_parent(downcast::<V>(array), parent, child_idx, ctx)? else {
180            return Ok(None);
181        };
182
183        if cfg!(debug_assertions) {
184            vortex_ensure!(
185                result.as_ref().len() == parent.len(),
186                "Executed parent canonical length mismatch"
187            );
188            vortex_ensure!(
189                result.as_ref().dtype() == parent.dtype(),
190                "Executed parent canonical dtype mismatch"
191            );
192        }
193
194        Ok(Some(result))
195    }
196}
197
198fn downcast<V: VTable>(array: &ArrayRef) -> &V::Array {
199    array
200        .as_any()
201        .downcast_ref::<ArrayAdapter<V>>()
202        .vortex_expect("Failed to downcast array to expected encoding type")
203        .as_inner()
204}
205
206impl<V: VTable> Debug for ArrayVTableAdapter<V> {
207    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
208        write!(f, "Encoding<{}>", type_name::<V>())
209    }
210}
211
212impl<V: VTable> From<V> for &'static dyn DynVTable {
213    fn from(_vtable: V) -> Self {
214        const { &ArrayVTableAdapter::<V>(PhantomData) }
215    }
216}
217
218pub trait ArrayVTableExt {
219    /// Wraps the vtable into an [`DynVTable`] by static reference.
220    fn vtable() -> &'static dyn DynVTable;
221}
222
223impl<V: VTable> ArrayVTableExt for V {
224    fn vtable() -> &'static dyn DynVTable {
225        const { &ArrayVTableAdapter::<V>(PhantomData) }
226    }
227}
228
229mod private {
230    use super::ArrayVTableAdapter;
231    use crate::vtable::VTable;
232
233    pub trait Sealed {}
234    impl<V: VTable> Sealed for ArrayVTableAdapter<V> {}
235}