vortex_layout/layouts/struct_/
mod.rs

1mod eval_expr;
2mod reader;
3pub mod writer;
4
5use std::collections::BTreeSet;
6use std::sync::Arc;
7
8use reader::StructReader;
9use vortex_array::ArrayContext;
10use vortex_dtype::{DType, Field, FieldMask};
11use vortex_error::{VortexResult, vortex_bail};
12
13use crate::data::Layout;
14use crate::reader::{LayoutReader, LayoutReaderExt};
15use crate::segments::SegmentSource;
16use crate::vtable::LayoutVTable;
17use crate::{LayoutId, STRUCT_LAYOUT_ID};
18
19#[derive(Debug)]
20pub struct StructLayout;
21
22impl LayoutVTable for StructLayout {
23    fn id(&self) -> LayoutId {
24        STRUCT_LAYOUT_ID
25    }
26
27    fn reader(
28        &self,
29        layout: Layout,
30        segment_source: &Arc<dyn SegmentSource>,
31        ctx: &ArrayContext,
32    ) -> VortexResult<Arc<dyn LayoutReader>> {
33        Ok(StructReader::try_new(layout, segment_source.clone(), ctx.clone())?.into_arc())
34    }
35
36    fn register_splits(
37        &self,
38        layout: &Layout,
39        field_mask: &[FieldMask],
40        row_offset: u64,
41        splits: &mut BTreeSet<u64>,
42    ) -> VortexResult<()> {
43        for_all_matching_children(layout, field_mask, |mask, child| {
44            child.register_splits(&[mask], row_offset, splits)
45        })?;
46        Ok(())
47    }
48}
49
50fn for_all_matching_children<F>(
51    layout: &Layout,
52    field_mask: &[FieldMask],
53    mut per_child: F,
54) -> VortexResult<()>
55where
56    F: FnMut(FieldMask, Layout) -> VortexResult<()>,
57{
58    let DType::Struct(dtype, _) = layout.dtype() else {
59        vortex_bail!("Mismatched dtype {} for struct layout", layout.dtype());
60    };
61
62    // If the field mask contains an `All` fields, then enumerate all fields.
63    if field_mask.iter().any(|mask| mask.matches_all()) {
64        for (idx, field_dtype) in dtype.fields().enumerate() {
65            let child = layout.child(idx, field_dtype, dtype.field_name(idx)?)?;
66            per_child(FieldMask::All, child)?;
67        }
68        return Ok(());
69    }
70
71    // Enumerate each field in the mask
72    for path in field_mask {
73        let Some(field) = path.starting_field()? else {
74            // skip fields not in mask
75            continue;
76        };
77        let Field::Name(field_name) = field else {
78            vortex_bail!("Expected field name, got {:?}", field);
79        };
80
81        let idx = dtype.find(field_name)?;
82        let child = layout.child(idx, dtype.field_by_index(idx)?, field_name)?;
83        per_child(path.clone().step_into()?, child)?;
84    }
85
86    Ok(())
87}