vortex_array/arrays/struct_/vtable/
pipeline.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::any::Any;
5use std::hash::{Hash, Hasher};
6use std::sync::Arc;
7
8use async_trait::async_trait;
9use futures::future::try_join_all;
10use vortex_dtype::DType;
11use vortex_error::{VortexExpect, VortexResult, vortex_err};
12
13use crate::arrays::{StructArray, StructVTable};
14use crate::operator::getitem::GetItemOperator;
15use crate::operator::{
16    BatchBindCtx, BatchExecution, BatchExecutionRef, BatchOperator, LengthBounds, Operator,
17    OperatorEq, OperatorHash, OperatorId, OperatorRef,
18};
19use crate::validity::Validity;
20use crate::vtable::PipelineVTable;
21use crate::{Array, ArrayRef, Canonical, IntoArray};
22
23impl PipelineVTable<StructVTable> for StructVTable {
24    fn to_operator(array: &StructArray) -> VortexResult<Option<OperatorRef>> {
25        let mut children = Vec::with_capacity(array.fields.len());
26        for field in array.fields().iter() {
27            if let Some(operator) = field.to_operator()? {
28                children.push(operator);
29            } else {
30                // If any of the children can't be converted, bail out.
31                return Ok(None);
32            }
33        }
34
35        Ok(Some(Arc::new(StructOperator {
36            dtype: array.dtype().clone(),
37            bounds: array.len().into(),
38            children,
39            // validity: array.validity.clone(),
40        })))
41    }
42}
43
/// An operator for a struct array.
///
/// Holds the struct's dtype, one child operator per struct field, and the length bounds of
/// the struct. Validity is not represented here yet (see the FIXME on the disabled field).
#[derive(Debug)]
struct StructOperator {
    /// The dtype reported by this operator; expected to be a struct dtype.
    dtype: DType,
    /// Child operators, in the same order as the struct's fields.
    children: Vec<OperatorRef>,
    /// Length bounds of the struct; `bind` requires these to resolve to a known length.
    bounds: LengthBounds,
    // FIXME(ngates): validity should be an operator too...
    // validity: Validity,
}
53
54impl OperatorHash for StructOperator {
55    fn operator_hash<H: Hasher>(&self, state: &mut H) {
56        self.dtype.hash(state);
57        self.bounds.hash(state);
58        for child in &self.children {
59            child.operator_hash(state);
60        }
61    }
62}
63
64impl OperatorEq for StructOperator {
65    fn operator_eq(&self, other: &Self) -> bool {
66        self.dtype == other.dtype
67            && self.bounds == other.bounds
68            && self.children.len() == other.children.len()
69            && self
70                .children
71                .iter()
72                .zip(other.children.iter())
73                .all(|(a, b)| a.operator_eq(b))
74    }
75}
76
77impl Operator for StructOperator {
78    fn id(&self) -> OperatorId {
79        OperatorId::from("vortex.struct")
80    }
81
82    fn as_any(&self) -> &dyn Any {
83        self
84    }
85
86    fn dtype(&self) -> &DType {
87        &self.dtype
88    }
89
90    fn bounds(&self) -> LengthBounds {
91        self.bounds
92    }
93
94    fn children(&self) -> &[OperatorRef] {
95        &self.children
96    }
97
98    fn with_children(self: Arc<Self>, children: Vec<OperatorRef>) -> VortexResult<OperatorRef> {
99        let bounds = LengthBounds::intersect_all(children.iter().map(|c| c.bounds()));
100        Ok(Arc::new(StructOperator {
101            dtype: self.dtype.clone(),
102            // validity: self.validity.clone(),
103            children,
104            bounds,
105        }))
106    }
107
108    fn reduce_parent(
109        &self,
110        parent: OperatorRef,
111        _child_idx: usize,
112    ) -> VortexResult<Option<OperatorRef>> {
113        // The only real things we know how to push-down are things that exclusively operate on
114        // validity, or operate on a single field.
115        if let Some(getitem) = parent.as_any().downcast_ref::<GetItemOperator>() {
116            let field_idx = self
117                .dtype
118                .as_struct_fields_opt()
119                .vortex_expect("Struct dtype must have fields")
120                .find(getitem.field_name())
121                .ok_or_else(|| {
122                    vortex_err!(
123                        "Field {} not found in struct {}",
124                        getitem.field_name(),
125                        &self.dtype
126                    )
127                })?;
128
129            // FIXME(ngates): intersect validity
130            return Ok(Some(self.children[field_idx].clone()));
131        }
132
133        Ok(None)
134    }
135}
136
137impl BatchOperator for StructOperator {
138    fn bind(&self, ctx: &mut dyn BatchBindCtx) -> VortexResult<BatchExecutionRef> {
139        let children = (0..self.children.len())
140            .map(|i| ctx.child(i))
141            .collect::<VortexResult<Vec<_>>>()?;
142
143        // TODO(ngates): we need custom push down logic for selection over a struct array in case
144        //  there are no children. Because in this case, we need to hold onto the selection mask
145        //  to know the true length.
146
147        Ok(Box::new(StructExecution {
148            len: self
149                .bounds
150                .maybe_len()
151                .ok_or_else(|| vortex_err!("StructOperator must have a known length"))?,
152            dtype: self.dtype.clone(),
153            children,
154            // validity: self.validity.clone(),
155        }))
156    }
157}
158
/// Bound execution state for a [`StructOperator`], produced by its `bind` implementation.
struct StructExecution {
    /// Number of rows in the struct array to build.
    len: usize,
    /// Struct dtype; its field names are used when assembling the output array.
    dtype: DType,
    /// Bound executions of the struct's fields, in field order.
    children: Vec<BatchExecutionRef>,
    // validity: Validity,
}
165
166#[async_trait]
167impl BatchExecution for StructExecution {
168    async fn execute(self: Box<Self>) -> VortexResult<Canonical> {
169        let children: Vec<_> =
170            try_join_all(self.children.into_iter().map(|child| child.execute())).await?;
171        let children: Vec<ArrayRef> = children
172            .into_iter()
173            .map(|canonical| canonical.into_array())
174            .collect();
175
176        let array = StructArray::new(
177            self.dtype
178                .as_struct_fields_opt()
179                .vortex_expect("Struct dtype must have fields")
180                .names()
181                .clone(),
182            children,
183            self.len,
184            // self.validity,
185            Validity::AllValid,
186        );
187
188        Ok(Canonical::Struct(array))
189    }
190}