vortex_layout/layouts/struct_/mod.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4mod reader;
5pub mod writer;
6
7use std::sync::Arc;
8
9use reader::StructReader;
10use vortex_array::{ArrayContext, DeserializeMetadata, EmptyMetadata};
11use vortex_dtype::{DType, Field, FieldMask, Nullability, StructFields};
12use vortex_error::{VortexExpect, VortexResult, vortex_bail, vortex_ensure, vortex_err};
13
14use crate::children::{LayoutChildren, OwnedLayoutChildren};
15use crate::segments::{SegmentId, SegmentSource};
16use crate::{
17    LayoutChildType, LayoutEncodingRef, LayoutId, LayoutReaderRef, LayoutRef, VTable, vtable,
18};
19
// NOTE(review): presumably generates the `StructVTable` type that the `VTable` impl below is
// written for (it is not declared elsewhere in this file) — confirm in the `vtable!` macro.
vtable!(Struct);
21
22impl VTable for StructVTable {
23    type Layout = StructLayout;
24    type Encoding = StructLayoutEncoding;
25    type Metadata = EmptyMetadata;
26
27    fn id(_encoding: &Self::Encoding) -> LayoutId {
28        LayoutId::new_ref("vortex.struct")
29    }
30
31    fn encoding(_layout: &Self::Layout) -> LayoutEncodingRef {
32        LayoutEncodingRef::new_ref(StructLayoutEncoding.as_ref())
33    }
34
35    fn row_count(layout: &Self::Layout) -> u64 {
36        layout.row_count
37    }
38
39    fn dtype(layout: &Self::Layout) -> &DType {
40        &layout.dtype
41    }
42
43    fn metadata(_layout: &Self::Layout) -> Self::Metadata {
44        EmptyMetadata
45    }
46
47    fn segment_ids(_layout: &Self::Layout) -> Vec<SegmentId> {
48        vec![]
49    }
50
51    fn nchildren(layout: &Self::Layout) -> usize {
52        let validity_children = if layout.dtype.is_nullable() { 1 } else { 0 };
53        layout.struct_fields().nfields() + validity_children
54    }
55
56    fn child(layout: &Self::Layout, index: usize) -> VortexResult<LayoutRef> {
57        let schema_index = if layout.dtype.is_nullable() {
58            index.saturating_sub(1)
59        } else {
60            index
61        };
62
63        let child_dtype = if index == 0 && layout.dtype.is_nullable() {
64            DType::Bool(Nullability::NonNullable)
65        } else {
66            layout
67                .struct_fields()
68                .field_by_index(schema_index)
69                .ok_or_else(|| vortex_err!("Missing field {schema_index}"))?
70        };
71
72        layout.children.child(index, &child_dtype)
73    }
74
75    fn child_type(layout: &Self::Layout, idx: usize) -> LayoutChildType {
76        let schema_index = if layout.dtype.is_nullable() {
77            idx.saturating_sub(1)
78        } else {
79            idx
80        };
81
82        if idx == 0 && layout.dtype.is_nullable() {
83            LayoutChildType::Auxiliary("validity".into())
84        } else {
85            LayoutChildType::Field(
86                layout
87                    .struct_fields()
88                    .field_name(schema_index)
89                    .vortex_expect("Field index out of bounds")
90                    .clone(),
91            )
92        }
93    }
94
95    fn new_reader(
96        layout: &Self::Layout,
97        name: Arc<str>,
98        segment_source: Arc<dyn SegmentSource>,
99    ) -> VortexResult<LayoutReaderRef> {
100        Ok(Arc::new(StructReader::try_new(
101            layout.clone(),
102            name,
103            segment_source,
104        )?))
105    }
106
107    #[cfg(gpu_unstable)]
108    fn new_gpu_reader(
109        layout: &Self::Layout,
110        name: Arc<str>,
111        segment_source: Arc<dyn SegmentSource>,
112        ctx: Arc<cudarc::driver::CudaContext>,
113    ) -> VortexResult<crate::gpu::GpuLayoutReaderRef> {
114        Ok(Arc::new(
115            crate::gpu::layouts::struct_::GpuStructReader::try_new(
116                layout.clone(),
117                name,
118                segment_source,
119                ctx,
120            )?,
121        ))
122    }
123
124    fn build(
125        _encoding: &Self::Encoding,
126        dtype: &DType,
127        row_count: u64,
128        _metadata: &<Self::Metadata as DeserializeMetadata>::Output,
129        _segment_ids: Vec<SegmentId>,
130        children: &dyn LayoutChildren,
131        _ctx: ArrayContext,
132    ) -> VortexResult<Self::Layout> {
133        let struct_dt = dtype
134            .as_struct_fields_opt()
135            .ok_or_else(|| vortex_err!("Expected struct dtype"))?;
136
137        let expected_children = struct_dt.nfields() + (dtype.is_nullable() as usize);
138        vortex_ensure!(
139            children.nchildren() == expected_children,
140            "Struct layout has {} children, but dtype has {} fields",
141            children.nchildren(),
142            struct_dt.nfields()
143        );
144
145        Ok(StructLayout {
146            row_count,
147            dtype: dtype.clone(),
148            children: children.to_arc(),
149        })
150    }
151}
152
/// Encoding marker for [`StructLayout`]; its layout id is `vortex.struct`.
#[derive(Debug)]
pub struct StructLayoutEncoding;
155
156#[derive(Clone, Debug)]
157pub struct StructLayout {
158    row_count: u64,
159    dtype: DType,
160    children: Arc<dyn LayoutChildren>,
161}
162
163impl StructLayout {
164    pub fn new(row_count: u64, dtype: DType, children: Vec<LayoutRef>) -> Self {
165        Self {
166            row_count,
167            dtype,
168            children: OwnedLayoutChildren::layout_children(children),
169        }
170    }
171
172    pub fn struct_fields(&self) -> &StructFields {
173        self.dtype
174            .as_struct_fields_opt()
175            .vortex_expect("Struct layout dtype must be a struct")
176    }
177
178    #[inline]
179    pub fn row_count(&self) -> u64 {
180        self.row_count
181    }
182
183    #[inline]
184    pub fn children(&self) -> &Arc<dyn LayoutChildren> {
185        &self.children
186    }
187
188    pub fn matching_fields<F>(&self, field_mask: &[FieldMask], mut per_child: F) -> VortexResult<()>
189    where
190        F: FnMut(FieldMask, usize) -> VortexResult<()>,
191    {
192        // If the field mask contains an `All` fields, then enumerate all fields.
193        if field_mask.iter().any(|mask| mask.matches_all()) {
194            for idx in 0..self.struct_fields().nfields() {
195                per_child(FieldMask::All, idx)?;
196            }
197            return Ok(());
198        }
199
200        // Enumerate each field in the mask
201        for path in field_mask {
202            let Some(field) = path.starting_field()? else {
203                // skip fields not in mask
204                continue;
205            };
206            let Field::Name(field_name) = field else {
207                vortex_bail!("Expected field name, got {field:?}");
208            };
209            let idx = self
210                .struct_fields()
211                .find(field_name)
212                .ok_or_else(|| vortex_err!("Field not found: {field_name}"))?;
213
214            per_child(path.clone().step_into()?, idx)?;
215        }
216
217        Ok(())
218    }
219}