vortex_array/pipeline/driver/
mod.rs1pub mod allocation;
5mod bind;
6mod input;
7mod toposort;
8
9use std::hash::{BuildHasher, Hash, Hasher};
10
11use itertools::Itertools;
12use vortex_dtype::DType;
13use vortex_error::{VortexResult, vortex_ensure};
14use vortex_mask::Mask;
15use vortex_utils::aliases::hash_map::{HashMap, RandomState};
16use vortex_vector::{Vector, VectorMut, VectorMutOps};
17
18use crate::pipeline::driver::allocation::{OutputTarget, allocate_vectors};
19use crate::pipeline::driver::bind::bind_kernels;
20use crate::pipeline::driver::toposort::topological_sort;
21use crate::pipeline::{BitView, Kernel, KernelCtx, N, PipelineInputs};
22use crate::{Array, ArrayEq, ArrayHash, ArrayOperator, ArrayRef, ArrayVisitor, Precision};
23
24#[derive(Clone, Debug)]
35pub(crate) struct PipelineDriver {
36 dag: Vec<Node>,
38 root: NodeId,
39
40 batch_inputs: Vec<ArrayRef>,
42}
43
/// Index of a `Node` inside `PipelineDriver::dag`.
type NodeId = usize;
/// Index of an array inside `PipelineDriver::batch_inputs`.
type BatchId = usize;
46
47#[derive(Debug, Clone)]
48struct Node {
49 array: ArrayRef,
51 #[allow(dead_code)] kind: NodeKind,
54 children: Vec<NodeId>,
56 parents: Vec<NodeId>,
58 batch_inputs: Vec<BatchId>,
60}
61
/// Classification assigned to each node by `PipelineDriver::new`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum NodeKind {
    /// The array reported no pipelined implementation (`as_pipelined()` was
    /// `None`); the array itself is registered as a batch input.
    Input,
    /// A pipelined array whose inputs are `PipelineInputs::Source`: it has no
    /// child nodes and all of its children are registered as batch inputs.
    Source,
    /// A pipelined array whose inputs are `PipelineInputs::Transform`: the
    /// children named in `pipelined_inputs` become child nodes, the remaining
    /// children are registered as batch inputs.
    Transform,
}
72
73impl PipelineDriver {
74 pub fn new(array: ArrayRef) -> PipelineDriver {
79 fn visit_node(
80 array: ArrayRef,
81 dag: &mut Vec<Node>,
82 batch: &mut Vec<ArrayRef>,
83 hash_to_id: &mut HashMap<u64, NodeId>,
84 random_state: &RandomState,
85 ) -> NodeId {
86 let subtree_hash = random_state.hash_one(ArrayKey(array.clone()));
88
89 if let Some(&existing_index) = hash_to_id.get(&subtree_hash) {
91 return existing_index;
93 }
94
95 let node = match array.as_pipelined() {
96 None => {
97 let batch_id = batch.len();
99 batch.push(array.clone());
100
101 Node {
102 array,
103 kind: NodeKind::Input,
104 children: vec![],
105 parents: vec![],
106 batch_inputs: vec![batch_id],
107 }
108 }
109 Some(pipelined) => match pipelined.inputs() {
110 PipelineInputs::Source => {
111 let children = array.children();
113 let mut batch_inputs = Vec::with_capacity(children.len());
114 for child in children {
115 batch_inputs.push(batch.len());
116 batch.push(child);
117 }
118
119 Node {
120 array,
121 kind: NodeKind::Source,
122 children: vec![],
123 parents: vec![],
124 batch_inputs,
125 }
126 }
127 PipelineInputs::Transform { pipelined_inputs } => {
128 let children = array.children();
130 let mut batch_inputs = Vec::with_capacity(children.len());
131 let mut pipeline_inputs = Vec::with_capacity(1);
132
133 for (child_idx, child) in children.into_iter().enumerate() {
134 if pipelined_inputs.contains(&child_idx) {
135 pipeline_inputs.push(visit_node(
136 child.clone(),
137 dag,
138 batch,
139 hash_to_id,
140 random_state,
141 ));
142 } else {
143 let batch_id = batch.len();
144 batch.push(child);
145 batch_inputs.push(batch_id);
146 }
147 }
148
149 Node {
150 array,
151 kind: NodeKind::Transform,
152 children: pipeline_inputs,
153 parents: vec![],
154 batch_inputs,
155 }
156 }
157 },
158 };
159
160 let node_id = dag.len();
161 dag.push(node);
162 hash_to_id.insert(subtree_hash, node_id);
163
164 node_id
165 }
166
167 let mut dag = vec![];
169 let mut batch = vec![];
170 let mut hash_to_id: HashMap<u64, NodeId> = HashMap::new();
171 let random_state = RandomState::default();
172 let root_index = visit_node(array, &mut dag, &mut batch, &mut hash_to_id, &random_state);
173
174 for i in 0..dag.len() {
176 let children = dag[i].children.clone();
177 for &child_idx in &children {
178 dag[child_idx].parents.push(i);
179 }
180 }
181
182 PipelineDriver {
183 root: root_index,
184 dag,
185 batch_inputs: batch,
186 }
187 }
188
189 fn root_array(&self) -> &ArrayRef {
190 &self.dag[self.root].array
191 }
192
193 pub fn execute(self, selection: &Mask) -> VortexResult<Vector> {
195 let dtype = self.root_array().dtype().clone();
196
197 let batch_inputs: Vec<_> = self
199 .batch_inputs
200 .into_iter()
201 .map(|array| array.execute().map(Some))
202 .try_collect()?;
203
204 let exec_order = topological_sort(&self.dag)?;
206
207 let allocation_plan = allocate_vectors(&self.dag, &exec_order)?;
209
210 let kernels = bind_kernels(self.dag, &allocation_plan, batch_inputs)?;
212
213 let ctx = KernelCtx::new(allocation_plan.vectors);
215
216 Pipeline {
217 dtype,
218 ctx,
219 kernels,
220 exec_order,
221 output_targets: allocation_plan.output_targets,
222 }
223 .execute(selection)
224 }
225}
226
227struct Pipeline {
228 dtype: DType,
229 ctx: KernelCtx,
230 kernels: Vec<Box<dyn Kernel>>,
231 exec_order: Vec<NodeId>,
232 output_targets: Vec<OutputTarget>,
233}
234
235impl Pipeline {
236 fn execute(&mut self, selection: &Mask) -> VortexResult<Vector> {
237 let capacity = selection.true_count().next_multiple_of(N);
239 let mut output = VectorMut::with_capacity(&self.dtype, capacity);
240
241 match selection {
242 Mask::AllFalse(_) => {}
243 Mask::AllTrue(_) => {
244 let nchunks = selection.len() / N;
247 for _ in 0..nchunks {
248 self.step(&BitView::all_true(), &mut output)?;
249 }
250
251 let remaining = selection.len() % N;
253 if remaining > 0 {
254 let selection_view = BitView::with_prefix(remaining);
255 self.step(&selection_view, &mut output)?;
256 }
257 }
258 Mask::Values(mask_values) => {
259 let selection_bits = mask_values.bit_buffer();
261 for selection_view in selection_bits.iter_bit_views() {
262 self.step(&selection_view, &mut output)?;
263 }
264 }
265 }
266
267 Ok(output.freeze())
268 }
269
270 fn step(&mut self, selection: &BitView, output: &mut VectorMut) -> VortexResult<()> {
272 for &node_idx in self.exec_order.iter() {
274 let kernel = &mut self.kernels[node_idx];
275
276 match &self.output_targets[node_idx] {
279 OutputTarget::ExternalOutput => {
280 let mut tail = output.split_off(output.len());
282 assert!(tail.is_empty());
283
284 kernel.step(&self.ctx, selection, &mut tail)?;
285
286 let len = tail.len();
287 vortex_ensure!(
288 len == N || len == selection.true_count(),
289 "Kernel produced incorrect number of output elements, \
290 expected either {N} or {}, got {len}",
291 selection.true_count(),
292 );
293
294 if selection.true_count() < N && len == N {
298 todo!("Filter via a bit mask")
300 }
301
302 output.unsplit(tail);
304 }
305 OutputTarget::IntermediateVector(vector_id) => {
306 let mut out_vector = self.ctx.take_output(vector_id);
307 out_vector.clear();
308 debug_assert!(out_vector.is_empty());
309
310 kernel.step(&self.ctx, selection, &mut out_vector)?;
311
312 let len = out_vector.len();
313 vortex_ensure!(
314 len == N || len == selection.true_count(),
315 "Kernel produced incorrect number of output elements, \
316 expected either {N} or {}, got {len}",
317 selection.true_count(),
318 );
319
320 self.ctx.replace_output(vector_id, out_vector);
322 }
323 };
324 }
325
326 Ok(())
327 }
328}
329
330struct ArrayKey(ArrayRef);
332impl Hash for ArrayKey {
333 fn hash<H: Hasher>(&self, mut state: &mut H) {
334 self.0.array_hash(&mut state, Precision::Ptr)
335 }
336}
337impl PartialEq for ArrayKey {
338 fn eq(&self, other: &Self) -> bool {
339 self.0.array_eq(&other.0, Precision::Ptr)
340 }
341}
342impl Eq for ArrayKey {}