Skip to main content

vortex_array/arrays/struct_/vtable/mod.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::sync::Arc;
5
6use itertools::Itertools;
7use kernel::PARENT_KERNELS;
8use vortex_error::VortexExpect;
9use vortex_error::VortexResult;
10use vortex_error::vortex_bail;
11use vortex_error::vortex_ensure;
12use vortex_error::vortex_panic;
13use vortex_session::VortexSession;
14
15use crate::ArrayRef;
16use crate::EmptyMetadata;
17use crate::ExecutionCtx;
18use crate::arrays::struct_::StructArray;
19use crate::arrays::struct_::compute::rules::PARENT_RULES;
20use crate::buffer::BufferHandle;
21use crate::dtype::DType;
22use crate::serde::ArrayChildren;
23use crate::validity::Validity;
24use crate::vtable;
25use crate::vtable::VTable;
26use crate::vtable::ValidityVTableFromValidityHelper;
27use crate::vtable::validity_nchildren;
28use crate::vtable::validity_to_child;
29mod kernel;
30mod operations;
31mod validity;
32use std::hash::Hash;
33
34use crate::Precision;
35use crate::hash::ArrayEq;
36use crate::hash::ArrayHash;
37use crate::stats::StatsSetRef;
38use crate::vtable::ArrayId;
39
40vtable!(Struct);
41
42impl VTable for StructVTable {
43    type Array = StructArray;
44
45    type Metadata = EmptyMetadata;
46    type OperationsVTable = Self;
47    type ValidityVTable = ValidityVTableFromValidityHelper;
48    fn id(_array: &Self::Array) -> ArrayId {
49        Self::ID
50    }
51
52    fn len(array: &StructArray) -> usize {
53        array.len
54    }
55
56    fn dtype(array: &StructArray) -> &DType {
57        &array.dtype
58    }
59
60    fn stats(array: &StructArray) -> StatsSetRef<'_> {
61        array.stats_set.to_ref(array.as_ref())
62    }
63
64    fn array_hash<H: std::hash::Hasher>(array: &StructArray, state: &mut H, precision: Precision) {
65        array.len.hash(state);
66        array.dtype.hash(state);
67        for field in array.fields.iter() {
68            field.array_hash(state, precision);
69        }
70        array.validity.array_hash(state, precision);
71    }
72
73    fn array_eq(array: &StructArray, other: &StructArray, precision: Precision) -> bool {
74        array.len == other.len
75            && array.dtype == other.dtype
76            && array.fields.len() == other.fields.len()
77            && array
78                .fields
79                .iter()
80                .zip(other.fields.iter())
81                .all(|(a, b)| a.array_eq(b, precision))
82            && array.validity.array_eq(&other.validity, precision)
83    }
84
85    fn nbuffers(_array: &StructArray) -> usize {
86        0
87    }
88
89    fn buffer(_array: &StructArray, idx: usize) -> BufferHandle {
90        vortex_panic!("StructArray buffer index {idx} out of bounds")
91    }
92
93    fn buffer_name(_array: &StructArray, idx: usize) -> Option<String> {
94        vortex_panic!("StructArray buffer_name index {idx} out of bounds")
95    }
96
97    fn nchildren(array: &StructArray) -> usize {
98        validity_nchildren(&array.validity) + array.unmasked_fields().len()
99    }
100
101    fn child(array: &StructArray, idx: usize) -> ArrayRef {
102        let vc = validity_nchildren(&array.validity);
103        if idx < vc {
104            validity_to_child(&array.validity, array.len())
105                .vortex_expect("StructArray validity child out of bounds")
106        } else {
107            array.unmasked_fields()[idx - vc].clone()
108        }
109    }
110
111    fn child_name(array: &StructArray, idx: usize) -> String {
112        let vc = validity_nchildren(&array.validity);
113        if idx < vc {
114            "validity".to_string()
115        } else {
116            array.names()[idx - vc].as_ref().to_string()
117        }
118    }
119
120    fn metadata(_array: &StructArray) -> VortexResult<Self::Metadata> {
121        Ok(EmptyMetadata)
122    }
123
124    fn serialize(_metadata: Self::Metadata) -> VortexResult<Option<Vec<u8>>> {
125        Ok(Some(vec![]))
126    }
127
128    fn deserialize(
129        _bytes: &[u8],
130        _dtype: &DType,
131        _len: usize,
132        _buffers: &[BufferHandle],
133        _session: &VortexSession,
134    ) -> VortexResult<Self::Metadata> {
135        Ok(EmptyMetadata)
136    }
137
138    fn build(
139        dtype: &DType,
140        len: usize,
141        _metadata: &Self::Metadata,
142        _buffers: &[BufferHandle],
143        children: &dyn ArrayChildren,
144    ) -> VortexResult<StructArray> {
145        let DType::Struct(struct_dtype, nullability) = dtype else {
146            vortex_bail!("Expected struct dtype, found {:?}", dtype)
147        };
148
149        let (validity, non_data_children) = if children.len() == struct_dtype.nfields() {
150            (Validity::from(*nullability), 0_usize)
151        } else if children.len() == struct_dtype.nfields() + 1 {
152            // Validity is the first child if it exists.
153            let validity = children.get(0, &Validity::DTYPE, len)?;
154            (Validity::Array(validity), 1_usize)
155        } else {
156            vortex_bail!(
157                "Expected {} or {} children, found {}",
158                struct_dtype.nfields(),
159                struct_dtype.nfields() + 1,
160                children.len()
161            );
162        };
163
164        let children: Vec<_> = (0..struct_dtype.nfields())
165            .map(|i| {
166                let child_dtype = struct_dtype
167                    .field_by_index(i)
168                    .vortex_expect("no out of bounds");
169                children.get(non_data_children + i, &child_dtype, len)
170            })
171            .try_collect()?;
172
173        StructArray::try_new_with_dtype(children, struct_dtype.clone(), len, validity)
174    }
175
176    fn with_children(array: &mut Self::Array, children: Vec<ArrayRef>) -> VortexResult<()> {
177        let DType::Struct(struct_dtype, _nullability) = &array.dtype else {
178            vortex_bail!("Expected struct dtype, found {:?}", array.dtype)
179        };
180
181        // First child is validity (if present), followed by fields
182        let (validity, non_data_children) = if children.len() == struct_dtype.nfields() {
183            (array.validity.clone(), 0_usize)
184        } else if children.len() == struct_dtype.nfields() + 1 {
185            (Validity::Array(children[0].clone()), 1_usize)
186        } else {
187            vortex_bail!(
188                "Expected {} or {} children, found {}",
189                struct_dtype.nfields(),
190                struct_dtype.nfields() + 1,
191                children.len()
192            );
193        };
194
195        let fields: Arc<[ArrayRef]> = children.into_iter().skip(non_data_children).collect();
196        vortex_ensure!(
197            fields.len() == struct_dtype.nfields(),
198            "Expected {} field children, found {}",
199            struct_dtype.nfields(),
200            fields.len()
201        );
202
203        array.fields = fields;
204        array.validity = validity;
205        Ok(())
206    }
207
208    fn execute(array: &Self::Array, _ctx: &mut ExecutionCtx) -> VortexResult<ArrayRef> {
209        Ok(array.to_array())
210    }
211
212    fn reduce_parent(
213        array: &Self::Array,
214        parent: &ArrayRef,
215        child_idx: usize,
216    ) -> VortexResult<Option<ArrayRef>> {
217        PARENT_RULES.evaluate(array, parent, child_idx)
218    }
219
220    fn execute_parent(
221        array: &Self::Array,
222        parent: &ArrayRef,
223        child_idx: usize,
224        ctx: &mut ExecutionCtx,
225    ) -> VortexResult<Option<ArrayRef>> {
226        PARENT_KERNELS.execute(array, parent, child_idx, ctx)
227    }
228}
229
230#[derive(Debug)]
231pub struct StructVTable;
232
233impl StructVTable {
234    pub const ID: ArrayId = ArrayId::new_ref("vortex.struct");
235}