Skip to main content

vortex_array/arrays/shared/
vtable.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::hash::Hasher;
5
6use vortex_error::VortexExpect;
7use vortex_error::VortexResult;
8use vortex_error::vortex_panic;
9use vortex_session::VortexSession;
10use vortex_session::registry::CachedId;
11
12use crate::ArrayEq;
13use crate::ArrayHash;
14use crate::ArrayRef;
15use crate::Canonical;
16use crate::ExecutionCtx;
17use crate::ExecutionResult;
18use crate::Precision;
19use crate::array::Array;
20use crate::array::ArrayId;
21use crate::array::ArrayView;
22use crate::array::OperationsVTable;
23use crate::array::VTable;
24use crate::array::ValidityVTable;
25use crate::arrays::shared::SharedArrayExt;
26use crate::arrays::shared::SharedData;
27use crate::arrays::shared::array::SLOT_NAMES;
28use crate::buffer::BufferHandle;
29use crate::dtype::DType;
30use crate::scalar::Scalar;
31use crate::validity::Validity;
32
/// A [`Shared`]-encoded Vortex array.
///
/// This is an alias for the generic [`Array`] wrapper parameterised by the [`Shared`]
/// encoding defined in this module.
pub type SharedArray = Array<Shared>;
35
// TODO(ngates): consider hooking Shared into the iterative execution model. Cache either the
//  most executed, or after each iteration, and return a shared cache for each execution.
/// Encoding marker for [`SharedArray`], identified by the id `vortex.shared`.
///
/// The unit struct carries no state of its own; it hosts the [`VTable`],
/// [`OperationsVTable`] and [`ValidityVTable`] implementations found in this file.
#[derive(Clone, Debug)]
pub struct Shared;
40
41impl ArrayHash for SharedData {
42    fn array_hash<H: Hasher>(&self, _state: &mut H, _precision: Precision) {}
43}
44
45impl ArrayEq for SharedData {
46    fn array_eq(&self, _other: &Self, _precision: Precision) -> bool {
47        true
48    }
49}
50
51impl VTable for Shared {
52    type ArrayData = SharedData;
53    type OperationsVTable = Self;
54    type ValidityVTable = Self;
55    fn id(&self) -> ArrayId {
56        static ID: CachedId = CachedId::new("vortex.shared");
57        *ID
58    }
59
60    fn validate(
61        &self,
62        _data: &SharedData,
63        dtype: &DType,
64        len: usize,
65        slots: &[Option<ArrayRef>],
66    ) -> VortexResult<()> {
67        let source = slots[0]
68            .as_ref()
69            .vortex_expect("SharedArray source slot must be present");
70        vortex_error::vortex_ensure!(source.dtype() == dtype, "SharedArray dtype mismatch");
71        vortex_error::vortex_ensure!(source.len() == len, "SharedArray len mismatch");
72        Ok(())
73    }
74
75    fn nbuffers(_array: ArrayView<'_, Self>) -> usize {
76        0
77    }
78
79    fn buffer(_array: ArrayView<'_, Self>, _idx: usize) -> BufferHandle {
80        vortex_panic!("SharedArray has no buffers")
81    }
82
83    fn buffer_name(_array: ArrayView<'_, Self>, _idx: usize) -> Option<String> {
84        None
85    }
86
87    fn slot_name(_array: ArrayView<'_, Self>, idx: usize) -> String {
88        SLOT_NAMES[idx].to_string()
89    }
90
91    fn serialize(
92        _array: ArrayView<'_, Self>,
93        _session: &VortexSession,
94    ) -> VortexResult<Option<Vec<u8>>> {
95        vortex_error::vortex_bail!("Shared array is not serializable")
96    }
97
98    fn deserialize(
99        &self,
100        _dtype: &DType,
101        _len: usize,
102        _metadata: &[u8],
103
104        _buffers: &[BufferHandle],
105        _children: &dyn crate::serde::ArrayChildren,
106        _session: &VortexSession,
107    ) -> VortexResult<crate::array::ArrayParts<Self>> {
108        vortex_error::vortex_bail!("Shared array is not serializable")
109    }
110
111    fn execute(array: Array<Self>, ctx: &mut ExecutionCtx) -> VortexResult<ExecutionResult> {
112        array
113            .get_or_compute(|source| source.clone().execute::<Canonical>(ctx))
114            .map(ExecutionResult::done)
115    }
116}
117impl OperationsVTable<Shared> for Shared {
118    fn scalar_at(
119        array: ArrayView<'_, Shared>,
120        index: usize,
121        ctx: &mut ExecutionCtx,
122    ) -> VortexResult<Scalar> {
123        array.current_array_ref().execute_scalar(index, ctx)
124    }
125}
126
127impl ValidityVTable<Shared> for Shared {
128    fn validity(array: ArrayView<'_, Shared>) -> VortexResult<Validity> {
129        array.current_array_ref().validity()
130    }
131}