Skip to main content

vortex_array/arrays/struct_/vtable/mod.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::sync::Arc;
5
6use itertools::Itertools;
7use kernel::PARENT_KERNELS;
8use vortex_error::VortexExpect;
9use vortex_error::VortexResult;
10use vortex_error::vortex_bail;
11use vortex_error::vortex_ensure;
12use vortex_error::vortex_panic;
13use vortex_session::VortexSession;
14
15use crate::ArrayRef;
16use crate::EmptyMetadata;
17use crate::ExecutionCtx;
18use crate::ExecutionResult;
19use crate::arrays::StructArray;
20use crate::arrays::struct_::compute::rules::PARENT_RULES;
21use crate::buffer::BufferHandle;
22use crate::dtype::DType;
23use crate::serde::ArrayChildren;
24use crate::validity::Validity;
25use crate::vtable;
26use crate::vtable::Array;
27use crate::vtable::VTable;
28use crate::vtable::ValidityVTableFromValidityHelper;
29use crate::vtable::validity_nchildren;
30use crate::vtable::validity_to_child;
31mod kernel;
32mod operations;
33mod validity;
34use std::hash::Hash;
35
36use crate::Precision;
37use crate::hash::ArrayEq;
38use crate::hash::ArrayHash;
39use crate::stats::StatsSetRef;
40use crate::vtable::ArrayId;
41
// Invoke the crate's `vtable!` macro for the `Struct` encoding. The expansion is defined
// alongside `crate::vtable` and is not visible here — NOTE(review): check that module for
// exactly which items it generates.
vtable!(Struct);
43
44impl VTable for Struct {
45    type Array = StructArray;
46
47    type Metadata = EmptyMetadata;
48    type OperationsVTable = Self;
49    type ValidityVTable = ValidityVTableFromValidityHelper;
50    fn vtable(_array: &Self::Array) -> &Self {
51        &Struct
52    }
53
54    fn id(&self) -> ArrayId {
55        Self::ID
56    }
57
58    fn len(array: &StructArray) -> usize {
59        array.len
60    }
61
62    fn dtype(array: &StructArray) -> &DType {
63        &array.dtype
64    }
65
66    fn stats(array: &StructArray) -> StatsSetRef<'_> {
67        array.stats_set.to_ref(array.as_ref())
68    }
69
70    fn array_hash<H: std::hash::Hasher>(array: &StructArray, state: &mut H, precision: Precision) {
71        array.len.hash(state);
72        array.dtype.hash(state);
73        for field in array.fields.iter() {
74            field.array_hash(state, precision);
75        }
76        array.validity.array_hash(state, precision);
77    }
78
79    fn array_eq(array: &StructArray, other: &StructArray, precision: Precision) -> bool {
80        array.len == other.len
81            && array.dtype == other.dtype
82            && array.fields.len() == other.fields.len()
83            && array
84                .fields
85                .iter()
86                .zip(other.fields.iter())
87                .all(|(a, b)| a.array_eq(b, precision))
88            && array.validity.array_eq(&other.validity, precision)
89    }
90
91    fn nbuffers(_array: &StructArray) -> usize {
92        0
93    }
94
95    fn buffer(_array: &StructArray, idx: usize) -> BufferHandle {
96        vortex_panic!("StructArray buffer index {idx} out of bounds")
97    }
98
99    fn buffer_name(_array: &StructArray, idx: usize) -> Option<String> {
100        vortex_panic!("StructArray buffer_name index {idx} out of bounds")
101    }
102
103    fn nchildren(array: &StructArray) -> usize {
104        validity_nchildren(&array.validity) + array.unmasked_fields().len()
105    }
106
107    fn child(array: &StructArray, idx: usize) -> ArrayRef {
108        let vc = validity_nchildren(&array.validity);
109        if idx < vc {
110            validity_to_child(&array.validity, array.len())
111                .vortex_expect("StructArray validity child out of bounds")
112        } else {
113            array.unmasked_fields()[idx - vc].clone()
114        }
115    }
116
117    fn child_name(array: &StructArray, idx: usize) -> String {
118        let vc = validity_nchildren(&array.validity);
119        if idx < vc {
120            "validity".to_string()
121        } else {
122            array.names()[idx - vc].as_ref().to_string()
123        }
124    }
125
126    fn metadata(_array: &StructArray) -> VortexResult<Self::Metadata> {
127        Ok(EmptyMetadata)
128    }
129
130    fn serialize(_metadata: Self::Metadata) -> VortexResult<Option<Vec<u8>>> {
131        Ok(Some(vec![]))
132    }
133
134    fn deserialize(
135        _bytes: &[u8],
136        _dtype: &DType,
137        _len: usize,
138        _buffers: &[BufferHandle],
139        _session: &VortexSession,
140    ) -> VortexResult<Self::Metadata> {
141        Ok(EmptyMetadata)
142    }
143
144    fn build(
145        dtype: &DType,
146        len: usize,
147        _metadata: &Self::Metadata,
148        _buffers: &[BufferHandle],
149        children: &dyn ArrayChildren,
150    ) -> VortexResult<StructArray> {
151        let DType::Struct(struct_dtype, nullability) = dtype else {
152            vortex_bail!("Expected struct dtype, found {:?}", dtype)
153        };
154
155        let (validity, non_data_children) = if children.len() == struct_dtype.nfields() {
156            (Validity::from(*nullability), 0_usize)
157        } else if children.len() == struct_dtype.nfields() + 1 {
158            // Validity is the first child if it exists.
159            let validity = children.get(0, &Validity::DTYPE, len)?;
160            (Validity::Array(validity), 1_usize)
161        } else {
162            vortex_bail!(
163                "Expected {} or {} children, found {}",
164                struct_dtype.nfields(),
165                struct_dtype.nfields() + 1,
166                children.len()
167            );
168        };
169
170        let children: Vec<_> = (0..struct_dtype.nfields())
171            .map(|i| {
172                let child_dtype = struct_dtype
173                    .field_by_index(i)
174                    .vortex_expect("no out of bounds");
175                children.get(non_data_children + i, &child_dtype, len)
176            })
177            .try_collect()?;
178
179        StructArray::try_new_with_dtype(children, struct_dtype.clone(), len, validity)
180    }
181
182    fn with_children(array: &mut Self::Array, children: Vec<ArrayRef>) -> VortexResult<()> {
183        let DType::Struct(struct_dtype, _nullability) = &array.dtype else {
184            vortex_bail!("Expected struct dtype, found {:?}", array.dtype)
185        };
186
187        // First child is validity (if present), followed by fields
188        let (validity, non_data_children) = if children.len() == struct_dtype.nfields() {
189            (array.validity.clone(), 0_usize)
190        } else if children.len() == struct_dtype.nfields() + 1 {
191            (Validity::Array(children[0].clone()), 1_usize)
192        } else {
193            vortex_bail!(
194                "Expected {} or {} children, found {}",
195                struct_dtype.nfields(),
196                struct_dtype.nfields() + 1,
197                children.len()
198            );
199        };
200
201        let fields: Arc<[ArrayRef]> = children.into_iter().skip(non_data_children).collect();
202        vortex_ensure!(
203            fields.len() == struct_dtype.nfields(),
204            "Expected {} field children, found {}",
205            struct_dtype.nfields(),
206            fields.len()
207        );
208
209        array.fields = fields;
210        array.validity = validity;
211        Ok(())
212    }
213
214    fn execute(array: Arc<Array<Self>>, _ctx: &mut ExecutionCtx) -> VortexResult<ExecutionResult> {
215        Ok(ExecutionResult::done(array))
216    }
217
218    fn reduce_parent(
219        array: &Array<Self>,
220        parent: &ArrayRef,
221        child_idx: usize,
222    ) -> VortexResult<Option<ArrayRef>> {
223        PARENT_RULES.evaluate(array, parent, child_idx)
224    }
225
226    fn execute_parent(
227        array: &Array<Self>,
228        parent: &ArrayRef,
229        child_idx: usize,
230        ctx: &mut ExecutionCtx,
231    ) -> VortexResult<Option<ArrayRef>> {
232        PARENT_KERNELS.execute(array, parent, child_idx, ctx)
233    }
234}
235
236#[derive(Clone, Debug)]
237pub struct Struct;
238
239impl Struct {
240    pub const ID: ArrayId = ArrayId::new_ref("vortex.struct");
241}