Skip to main content

vortex_array/arrays/struct_/vtable/
mod.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::sync::Arc;
5
6use itertools::Itertools;
7use kernel::PARENT_KERNELS;
8use vortex_error::VortexExpect;
9use vortex_error::VortexResult;
10use vortex_error::vortex_bail;
11use vortex_error::vortex_ensure;
12use vortex_error::vortex_panic;
13use vortex_session::VortexSession;
14
15use crate::ArrayRef;
16use crate::EmptyMetadata;
17use crate::ExecutionCtx;
18use crate::ExecutionResult;
19use crate::arrays::StructArray;
20use crate::arrays::struct_::compute::rules::PARENT_RULES;
21use crate::buffer::BufferHandle;
22use crate::dtype::DType;
23use crate::serde::ArrayChildren;
24use crate::validity::Validity;
25use crate::vtable;
26use crate::vtable::VTable;
27use crate::vtable::ValidityVTableFromValidityHelper;
28use crate::vtable::validity_nchildren;
29use crate::vtable::validity_to_child;
30mod kernel;
31mod operations;
32mod validity;
33use std::hash::Hash;
34
35use crate::Precision;
36use crate::hash::ArrayEq;
37use crate::hash::ArrayHash;
38use crate::stats::StatsSetRef;
39use crate::vtable::ArrayId;
40
// Invokes the `vtable!` macro for the `Struct` encoding.
// NOTE(review): the macro body is not visible in this file; presumably it generates the
// glue tying `Struct` to `StructArray` — confirm against `crate::vtable`.
vtable!(Struct);
42
43impl VTable for Struct {
44    type Array = StructArray;
45
46    type Metadata = EmptyMetadata;
47    type OperationsVTable = Self;
48    type ValidityVTable = ValidityVTableFromValidityHelper;
49    fn vtable(_array: &Self::Array) -> &Self {
50        &Struct
51    }
52
53    fn id(&self) -> ArrayId {
54        Self::ID
55    }
56
57    fn len(array: &StructArray) -> usize {
58        array.len
59    }
60
61    fn dtype(array: &StructArray) -> &DType {
62        &array.dtype
63    }
64
65    fn stats(array: &StructArray) -> StatsSetRef<'_> {
66        array.stats_set.to_ref(array.as_ref())
67    }
68
69    fn array_hash<H: std::hash::Hasher>(array: &StructArray, state: &mut H, precision: Precision) {
70        array.len.hash(state);
71        array.dtype.hash(state);
72        for field in array.fields.iter() {
73            field.array_hash(state, precision);
74        }
75        array.validity.array_hash(state, precision);
76    }
77
78    fn array_eq(array: &StructArray, other: &StructArray, precision: Precision) -> bool {
79        array.len == other.len
80            && array.dtype == other.dtype
81            && array.fields.len() == other.fields.len()
82            && array
83                .fields
84                .iter()
85                .zip(other.fields.iter())
86                .all(|(a, b)| a.array_eq(b, precision))
87            && array.validity.array_eq(&other.validity, precision)
88    }
89
90    fn nbuffers(_array: &StructArray) -> usize {
91        0
92    }
93
94    fn buffer(_array: &StructArray, idx: usize) -> BufferHandle {
95        vortex_panic!("StructArray buffer index {idx} out of bounds")
96    }
97
98    fn buffer_name(_array: &StructArray, idx: usize) -> Option<String> {
99        vortex_panic!("StructArray buffer_name index {idx} out of bounds")
100    }
101
102    fn nchildren(array: &StructArray) -> usize {
103        validity_nchildren(&array.validity) + array.unmasked_fields().len()
104    }
105
106    fn child(array: &StructArray, idx: usize) -> ArrayRef {
107        let vc = validity_nchildren(&array.validity);
108        if idx < vc {
109            validity_to_child(&array.validity, array.len())
110                .vortex_expect("StructArray validity child out of bounds")
111        } else {
112            array.unmasked_fields()[idx - vc].clone()
113        }
114    }
115
116    fn child_name(array: &StructArray, idx: usize) -> String {
117        let vc = validity_nchildren(&array.validity);
118        if idx < vc {
119            "validity".to_string()
120        } else {
121            array.names()[idx - vc].as_ref().to_string()
122        }
123    }
124
125    fn metadata(_array: &StructArray) -> VortexResult<Self::Metadata> {
126        Ok(EmptyMetadata)
127    }
128
129    fn serialize(_metadata: Self::Metadata) -> VortexResult<Option<Vec<u8>>> {
130        Ok(Some(vec![]))
131    }
132
133    fn deserialize(
134        _bytes: &[u8],
135        _dtype: &DType,
136        _len: usize,
137        _buffers: &[BufferHandle],
138        _session: &VortexSession,
139    ) -> VortexResult<Self::Metadata> {
140        Ok(EmptyMetadata)
141    }
142
143    fn build(
144        dtype: &DType,
145        len: usize,
146        _metadata: &Self::Metadata,
147        _buffers: &[BufferHandle],
148        children: &dyn ArrayChildren,
149    ) -> VortexResult<StructArray> {
150        let DType::Struct(struct_dtype, nullability) = dtype else {
151            vortex_bail!("Expected struct dtype, found {:?}", dtype)
152        };
153
154        let (validity, non_data_children) = if children.len() == struct_dtype.nfields() {
155            (Validity::from(*nullability), 0_usize)
156        } else if children.len() == struct_dtype.nfields() + 1 {
157            // Validity is the first child if it exists.
158            let validity = children.get(0, &Validity::DTYPE, len)?;
159            (Validity::Array(validity), 1_usize)
160        } else {
161            vortex_bail!(
162                "Expected {} or {} children, found {}",
163                struct_dtype.nfields(),
164                struct_dtype.nfields() + 1,
165                children.len()
166            );
167        };
168
169        let children: Vec<_> = (0..struct_dtype.nfields())
170            .map(|i| {
171                let child_dtype = struct_dtype
172                    .field_by_index(i)
173                    .vortex_expect("no out of bounds");
174                children.get(non_data_children + i, &child_dtype, len)
175            })
176            .try_collect()?;
177
178        StructArray::try_new_with_dtype(children, struct_dtype.clone(), len, validity)
179    }
180
181    fn with_children(array: &mut Self::Array, children: Vec<ArrayRef>) -> VortexResult<()> {
182        let DType::Struct(struct_dtype, _nullability) = &array.dtype else {
183            vortex_bail!("Expected struct dtype, found {:?}", array.dtype)
184        };
185
186        // First child is validity (if present), followed by fields
187        let (validity, non_data_children) = if children.len() == struct_dtype.nfields() {
188            (array.validity.clone(), 0_usize)
189        } else if children.len() == struct_dtype.nfields() + 1 {
190            (Validity::Array(children[0].clone()), 1_usize)
191        } else {
192            vortex_bail!(
193                "Expected {} or {} children, found {}",
194                struct_dtype.nfields(),
195                struct_dtype.nfields() + 1,
196                children.len()
197            );
198        };
199
200        let fields: Arc<[ArrayRef]> = children.into_iter().skip(non_data_children).collect();
201        vortex_ensure!(
202            fields.len() == struct_dtype.nfields(),
203            "Expected {} field children, found {}",
204            struct_dtype.nfields(),
205            fields.len()
206        );
207
208        array.fields = fields;
209        array.validity = validity;
210        Ok(())
211    }
212
213    fn execute(array: Arc<Self::Array>, _ctx: &mut ExecutionCtx) -> VortexResult<ExecutionResult> {
214        Ok(ExecutionResult::done_upcast::<Self>(array))
215    }
216
217    fn reduce_parent(
218        array: &Self::Array,
219        parent: &ArrayRef,
220        child_idx: usize,
221    ) -> VortexResult<Option<ArrayRef>> {
222        PARENT_RULES.evaluate(array, parent, child_idx)
223    }
224
225    fn execute_parent(
226        array: &Self::Array,
227        parent: &ArrayRef,
228        child_idx: usize,
229        ctx: &mut ExecutionCtx,
230    ) -> VortexResult<Option<ArrayRef>> {
231        PARENT_KERNELS.execute(array, parent, child_idx, ctx)
232    }
233}
234
/// Unit marker type on which this module implements [`VTable`] for struct arrays.
#[derive(Debug, Clone)]
pub struct Struct;
237
impl Struct {
    /// Identifier of the struct encoding, built from the string `"vortex.struct"`.
    pub const ID: ArrayId = ArrayId::new_ref("vortex.struct");
}