vortex_array/arrays/struct_/vtable/
pipeline.rs1use std::any::Any;
5use std::hash::{Hash, Hasher};
6use std::sync::Arc;
7
8use async_trait::async_trait;
9use futures::future::try_join_all;
10use vortex_dtype::DType;
11use vortex_error::{VortexExpect, VortexResult, vortex_err};
12
13use crate::arrays::{StructArray, StructVTable};
14use crate::operator::getitem::GetItemOperator;
15use crate::operator::{
16 BatchBindCtx, BatchExecution, BatchExecutionRef, BatchOperator, LengthBounds, Operator,
17 OperatorEq, OperatorHash, OperatorId, OperatorRef,
18};
19use crate::validity::Validity;
20use crate::vtable::PipelineVTable;
21use crate::{Array, ArrayRef, Canonical, IntoArray};
22
23impl PipelineVTable<StructVTable> for StructVTable {
24 fn to_operator(array: &StructArray) -> VortexResult<Option<OperatorRef>> {
25 let mut children = Vec::with_capacity(array.fields.len());
26 for field in array.fields().iter() {
27 if let Some(operator) = field.to_operator()? {
28 children.push(operator);
29 } else {
30 return Ok(None);
32 }
33 }
34
35 Ok(Some(Arc::new(StructOperator {
36 dtype: array.dtype().clone(),
37 bounds: array.len().into(),
38 children,
39 })))
41 }
42}
43
/// Pipeline operator that assembles a struct value from one child operator per struct field.
#[derive(Debug)]
struct StructOperator {
    /// The struct dtype produced by this operator; its field list is used to map field names
    /// to child indices.
    dtype: DType,
    /// One operator per struct field, indexed by the field's position in `dtype`.
    children: Vec<OperatorRef>,
    /// Length bounds of the produced struct.
    bounds: LengthBounds,
}
53
54impl OperatorHash for StructOperator {
55 fn operator_hash<H: Hasher>(&self, state: &mut H) {
56 self.dtype.hash(state);
57 self.bounds.hash(state);
58 for child in &self.children {
59 child.operator_hash(state);
60 }
61 }
62}
63
64impl OperatorEq for StructOperator {
65 fn operator_eq(&self, other: &Self) -> bool {
66 self.dtype == other.dtype
67 && self.bounds == other.bounds
68 && self.children.len() == other.children.len()
69 && self
70 .children
71 .iter()
72 .zip(other.children.iter())
73 .all(|(a, b)| a.operator_eq(b))
74 }
75}
76
77impl Operator for StructOperator {
78 fn id(&self) -> OperatorId {
79 OperatorId::from("vortex.struct")
80 }
81
82 fn as_any(&self) -> &dyn Any {
83 self
84 }
85
86 fn dtype(&self) -> &DType {
87 &self.dtype
88 }
89
90 fn bounds(&self) -> LengthBounds {
91 self.bounds
92 }
93
94 fn children(&self) -> &[OperatorRef] {
95 &self.children
96 }
97
98 fn with_children(self: Arc<Self>, children: Vec<OperatorRef>) -> VortexResult<OperatorRef> {
99 let bounds = LengthBounds::intersect_all(children.iter().map(|c| c.bounds()));
100 Ok(Arc::new(StructOperator {
101 dtype: self.dtype.clone(),
102 children,
104 bounds,
105 }))
106 }
107
108 fn reduce_parent(
109 &self,
110 parent: OperatorRef,
111 _child_idx: usize,
112 ) -> VortexResult<Option<OperatorRef>> {
113 if let Some(getitem) = parent.as_any().downcast_ref::<GetItemOperator>() {
116 let field_idx = self
117 .dtype
118 .as_struct_fields_opt()
119 .vortex_expect("Struct dtype must have fields")
120 .find(getitem.field_name())
121 .ok_or_else(|| {
122 vortex_err!(
123 "Field {} not found in struct {}",
124 getitem.field_name(),
125 &self.dtype
126 )
127 })?;
128
129 return Ok(Some(self.children[field_idx].clone()));
131 }
132
133 Ok(None)
134 }
135}
136
137impl BatchOperator for StructOperator {
138 fn bind(&self, ctx: &mut dyn BatchBindCtx) -> VortexResult<BatchExecutionRef> {
139 let children = (0..self.children.len())
140 .map(|i| ctx.child(i))
141 .collect::<VortexResult<Vec<_>>>()?;
142
143 Ok(Box::new(StructExecution {
148 len: self
149 .bounds
150 .maybe_len()
151 .ok_or_else(|| vortex_err!("StructOperator must have a known length"))?,
152 dtype: self.dtype.clone(),
153 children,
154 }))
156 }
157}
158
159struct StructExecution {
160 len: usize,
161 dtype: DType,
162 children: Vec<BatchExecutionRef>,
163 }
165
166#[async_trait]
167impl BatchExecution for StructExecution {
168 async fn execute(self: Box<Self>) -> VortexResult<Canonical> {
169 let children: Vec<_> =
170 try_join_all(self.children.into_iter().map(|child| child.execute())).await?;
171 let children: Vec<ArrayRef> = children
172 .into_iter()
173 .map(|canonical| canonical.into_array())
174 .collect();
175
176 let array = StructArray::new(
177 self.dtype
178 .as_struct_fields_opt()
179 .vortex_expect("Struct dtype must have fields")
180 .names()
181 .clone(),
182 children,
183 self.len,
184 Validity::AllValid,
186 );
187
188 Ok(Canonical::Struct(array))
189 }
190}