vortex_array/arrays/struct_/
operator.rs1use std::any::Any;
5use std::hash::{Hash, Hasher};
6use std::sync::Arc;
7
8use async_trait::async_trait;
9use futures::future::try_join_all;
10use vortex_dtype::DType;
11use vortex_error::{VortexExpect, VortexResult, vortex_err};
12
13use crate::arrays::{StructArray, StructVTable};
14use crate::operator::getitem::GetItemOperator;
15use crate::operator::{
16 BatchBindCtx, BatchExecution, BatchExecutionRef, BatchOperator, Operator, OperatorEq,
17 OperatorHash, OperatorId, OperatorRef,
18};
19use crate::validity::Validity;
20use crate::vtable::PipelineVTable;
21use crate::{Array, Canonical, IntoArray};
22
23impl PipelineVTable<StructVTable> for StructVTable {
24 fn to_operator(array: &StructArray) -> VortexResult<Option<OperatorRef>> {
25 let mut children = Vec::with_capacity(array.fields.len());
26 for field in array.fields() {
27 if let Some(operator) = field.to_operator()? {
28 children.push(operator);
29 } else {
30 return Ok(None);
32 }
33 }
34
35 Ok(Some(Arc::new(StructOperator {
36 dtype: array.dtype().clone(),
37 len: array.len(),
38 children,
39 })))
41 }
42}
43
44#[derive(Debug)]
46struct StructOperator {
47 dtype: DType,
48 len: usize,
49 children: Vec<OperatorRef>,
50 }
53
54impl OperatorHash for StructOperator {
55 fn operator_hash<H: Hasher>(&self, state: &mut H) {
56 self.dtype.hash(state);
57 self.len.hash(state);
58 for child in &self.children {
59 child.operator_hash(state);
60 }
61 }
62}
63
64impl OperatorEq for StructOperator {
65 fn operator_eq(&self, other: &Self) -> bool {
66 self.dtype == other.dtype
67 && self.len == other.len
68 && self.children.len() == other.children.len()
69 && self
70 .children
71 .iter()
72 .zip(other.children.iter())
73 .all(|(a, b)| a.operator_eq(b))
74 }
75}
76
77impl Operator for StructOperator {
78 fn id(&self) -> OperatorId {
79 OperatorId::from("vortex.struct")
80 }
81
82 fn as_any(&self) -> &dyn Any {
83 self
84 }
85
86 fn dtype(&self) -> &DType {
87 &self.dtype
88 }
89
90 fn len(&self) -> usize {
91 self.len
92 }
93
94 fn children(&self) -> &[OperatorRef] {
95 &self.children
96 }
97
98 fn with_children(self: Arc<Self>, children: Vec<OperatorRef>) -> VortexResult<OperatorRef> {
99 Ok(Arc::new(StructOperator {
100 len: self.len,
101 dtype: self.dtype.clone(),
102 children,
104 }))
105 }
106
107 fn reduce_parent(
108 &self,
109 parent: OperatorRef,
110 _child_idx: usize,
111 ) -> VortexResult<Option<OperatorRef>> {
112 if let Some(getitem) = parent.as_any().downcast_ref::<GetItemOperator>() {
115 let field_idx = self
116 .dtype
117 .as_struct_fields_opt()
118 .vortex_expect("Struct dtype must have fields")
119 .find(getitem.field_name())
120 .ok_or_else(|| {
121 vortex_err!(
122 "Field {} not found in struct {}",
123 getitem.field_name(),
124 &self.dtype
125 )
126 })?;
127
128 return Ok(Some(self.children[field_idx].clone()));
130 }
131
132 Ok(None)
133 }
134}
135
136impl BatchOperator for StructOperator {
137 fn bind(&self, ctx: &mut dyn BatchBindCtx) -> VortexResult<BatchExecutionRef> {
138 let children = (0..self.children.len())
139 .map(|i| ctx.child(i))
140 .collect::<VortexResult<Vec<_>>>()?;
141 Ok(Box::new(StructExecution {
142 len: self.len,
143 dtype: self.dtype.clone(),
144 children,
145 }))
147 }
148}
149
150struct StructExecution {
151 len: usize,
152 dtype: DType,
153 children: Vec<BatchExecutionRef>,
154 }
156
157#[async_trait]
158impl BatchExecution for StructExecution {
159 async fn execute(self: Box<Self>) -> VortexResult<Canonical> {
160 let children: Vec<_> =
161 try_join_all(self.children.into_iter().map(|child| child.execute())).await?;
162 let children = children
163 .into_iter()
164 .map(|canonical| canonical.into_array())
165 .collect();
166
167 let array = StructArray::new(
168 self.dtype
169 .as_struct_fields_opt()
170 .vortex_expect("Struct dtype must have fields")
171 .names()
172 .clone(),
173 children,
174 self.len,
175 Validity::AllValid,
177 );
178
179 Ok(Canonical::Struct(array))
180 }
181}