vortex_array/arrays/struct_/
operator.rs

1// SPDX-License-Identifier: Apache-2.0
2// SPDX-FileCopyrightText: Copyright the Vortex contributors
3
4use std::any::Any;
5use std::hash::{Hash, Hasher};
6use std::sync::Arc;
7
8use async_trait::async_trait;
9use futures::future::try_join_all;
10use vortex_dtype::DType;
11use vortex_error::{VortexExpect, VortexResult, vortex_err};
12
13use crate::arrays::{StructArray, StructVTable};
14use crate::operator::getitem::GetItemOperator;
15use crate::operator::{
16    BatchBindCtx, BatchExecution, BatchExecutionRef, BatchOperator, Operator, OperatorEq,
17    OperatorHash, OperatorId, OperatorRef,
18};
19use crate::validity::Validity;
20use crate::vtable::PipelineVTable;
21use crate::{Array, Canonical, IntoArray};
22
23impl PipelineVTable<StructVTable> for StructVTable {
24    fn to_operator(array: &StructArray) -> VortexResult<Option<OperatorRef>> {
25        let mut children = Vec::with_capacity(array.fields.len());
26        for field in array.fields() {
27            if let Some(operator) = field.to_operator()? {
28                children.push(operator);
29            } else {
30                // If any of the children can't be converted, bail out.
31                return Ok(None);
32            }
33        }
34
35        Ok(Some(Arc::new(StructOperator {
36            dtype: array.dtype().clone(),
37            len: array.len(),
38            children,
39            // validity: array.validity.clone(),
40        })))
41    }
42}
43
44/// An operator for a struct array.
45#[derive(Debug)]
46struct StructOperator {
47    dtype: DType,
48    len: usize,
49    children: Vec<OperatorRef>,
50    // FIXME(ngates): validity should be an operator too...
51    // validity: Validity,
52}
53
54impl OperatorHash for StructOperator {
55    fn operator_hash<H: Hasher>(&self, state: &mut H) {
56        self.dtype.hash(state);
57        self.len.hash(state);
58        for child in &self.children {
59            child.operator_hash(state);
60        }
61    }
62}
63
64impl OperatorEq for StructOperator {
65    fn operator_eq(&self, other: &Self) -> bool {
66        self.dtype == other.dtype
67            && self.len == other.len
68            && self.children.len() == other.children.len()
69            && self
70                .children
71                .iter()
72                .zip(other.children.iter())
73                .all(|(a, b)| a.operator_eq(b))
74    }
75}
76
77impl Operator for StructOperator {
78    fn id(&self) -> OperatorId {
79        OperatorId::from("vortex.struct")
80    }
81
82    fn as_any(&self) -> &dyn Any {
83        self
84    }
85
86    fn dtype(&self) -> &DType {
87        &self.dtype
88    }
89
90    fn len(&self) -> usize {
91        self.len
92    }
93
94    fn children(&self) -> &[OperatorRef] {
95        &self.children
96    }
97
98    fn with_children(self: Arc<Self>, children: Vec<OperatorRef>) -> VortexResult<OperatorRef> {
99        Ok(Arc::new(StructOperator {
100            len: self.len,
101            dtype: self.dtype.clone(),
102            // validity: self.validity.clone(),
103            children,
104        }))
105    }
106
107    fn reduce_parent(
108        &self,
109        parent: OperatorRef,
110        _child_idx: usize,
111    ) -> VortexResult<Option<OperatorRef>> {
112        // The only real things we know how to push-down are things that exclusively operate on
113        // validity, or operate on a single field.
114        if let Some(getitem) = parent.as_any().downcast_ref::<GetItemOperator>() {
115            let field_idx = self
116                .dtype
117                .as_struct_fields_opt()
118                .vortex_expect("Struct dtype must have fields")
119                .find(getitem.field_name())
120                .ok_or_else(|| {
121                    vortex_err!(
122                        "Field {} not found in struct {}",
123                        getitem.field_name(),
124                        &self.dtype
125                    )
126                })?;
127
128            // FIXME(ngates): intersect validity
129            return Ok(Some(self.children[field_idx].clone()));
130        }
131
132        Ok(None)
133    }
134}
135
136impl BatchOperator for StructOperator {
137    fn bind(&self, ctx: &mut dyn BatchBindCtx) -> VortexResult<BatchExecutionRef> {
138        let children = (0..self.children.len())
139            .map(|i| ctx.child(i))
140            .collect::<VortexResult<Vec<_>>>()?;
141        Ok(Box::new(StructExecution {
142            len: self.len,
143            dtype: self.dtype.clone(),
144            children,
145            // validity: self.validity.clone(),
146        }))
147    }
148}
149
150struct StructExecution {
151    len: usize,
152    dtype: DType,
153    children: Vec<BatchExecutionRef>,
154    // validity: Validity,
155}
156
157#[async_trait]
158impl BatchExecution for StructExecution {
159    async fn execute(self: Box<Self>) -> VortexResult<Canonical> {
160        let children: Vec<_> =
161            try_join_all(self.children.into_iter().map(|child| child.execute())).await?;
162        let children = children
163            .into_iter()
164            .map(|canonical| canonical.into_array())
165            .collect();
166
167        let array = StructArray::new(
168            self.dtype
169                .as_struct_fields_opt()
170                .vortex_expect("Struct dtype must have fields")
171                .names()
172                .clone(),
173            children,
174            self.len,
175            // self.validity,
176            Validity::AllValid,
177        );
178
179        Ok(Canonical::Struct(array))
180    }
181}