1use std::any::Any;
51use std::collections::{HashMap, HashSet};
52use std::fmt;
53use std::pin::Pin;
54use std::sync::Arc;
55use std::task::{Context, Poll};
56
57use arrow_array::builder::UInt32Builder;
58use arrow_array::{Array, ArrayRef, RecordBatch, UInt64Array};
59use arrow_schema::{Field, Schema, SchemaRef};
60use datafusion::common::{Result as DFResult, ScalarValue};
61use datafusion::execution::{RecordBatchStream, SendableRecordBatchStream, TaskContext};
62use datafusion::physical_plan::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet};
63use datafusion::physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties};
64use futures::{Stream, TryStreamExt};
65
66use super::common::compute_plan_properties;
67use super::scan::GraphScanExec;
68
/// Upper bound on the number of VIDs handed to a single
/// `GraphScanExec::execute_with_vid_filter` call.
///
/// `run_join` splits the distinct build-side anchor VIDs into slices of at
/// most this many entries and issues one filtered probe scan per slice.
pub(crate) const MAX_VIDS_PER_CHUNK: usize = 10_000;
73
/// Identifies which child of a [`VidLookupJoinExec`] is the probe side.
///
/// The other child is the build side. `VidLookupJoinExec::try_new` requires
/// the probe-side child to be a `GraphScanExec`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ProbeSide {
    /// The left child is probed; the right child is the build side.
    Left,
    /// The right child is probed; the left child is the build side.
    Right,
}
80
81#[derive(Debug, Clone, Copy)]
85pub struct EquiPair {
86 pub left_col_idx: usize,
87 pub right_col_idx: usize,
88}
89
90impl EquiPair {
91 fn build_col(&self, probe_side: ProbeSide) -> usize {
93 match probe_side {
94 ProbeSide::Left => self.right_col_idx,
95 ProbeSide::Right => self.left_col_idx,
96 }
97 }
98
99 fn probe_col(&self, probe_side: ProbeSide) -> usize {
101 match probe_side {
102 ProbeSide::Left => self.left_col_idx,
103 ProbeSide::Right => self.right_col_idx,
104 }
105 }
106}
107
/// Join semantics supported by [`VidLookupJoinExec`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum VidJoinKind {
    /// Only build rows that have at least one matching probe row are emitted.
    Inner,
    /// Build rows without a matching probe row (including rows whose anchor
    /// VID is null) are also emitted, with nulls in every probe-side column.
    Left,
}
114
/// Physical operator that joins a build-side plan against a `GraphScanExec`
/// probe by looking up the VIDs found on the build side.
///
/// The build child is executed and collected in full; its distinct anchor
/// VIDs are then used to run filtered scans on the probe child.
pub struct VidLookupJoinExec {
    /// Left input; its columns come first in the output schema.
    left: Arc<dyn ExecutionPlan>,
    /// Right input; its columns follow the left columns in the output schema.
    right: Arc<dyn ExecutionPlan>,
    /// Which of `left` / `right` is the probe (`GraphScanExec`) side.
    probe_side: ProbeSide,
    /// Equality conditions. `pairs[0]` is the anchor pair (the VID columns);
    /// any remaining pairs are checked row by row after the VID lookup.
    pairs: Vec<EquiPair>,
    /// Inner or left join behaviour for build rows without a match.
    join_kind: VidJoinKind,
    /// Schema of the produced batches: left fields followed by right fields.
    output_schema: SchemaRef,
    /// Plan properties computed from `output_schema` by `compute_plan_properties`.
    properties: Arc<PlanProperties>,
    /// Metrics registry shared by all partitions of this operator.
    metrics: ExecutionPlanMetricsSet,
}
129
130impl fmt::Debug for VidLookupJoinExec {
131 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
132 f.debug_struct("VidLookupJoinExec")
133 .field("probe_side", &self.probe_side)
134 .field("pairs", &self.pairs.len())
135 .field("join_kind", &self.join_kind)
136 .finish()
137 }
138}
139
140impl VidLookupJoinExec {
141 pub fn try_new(
150 left: Arc<dyn ExecutionPlan>,
151 right: Arc<dyn ExecutionPlan>,
152 probe_side: ProbeSide,
153 pairs: Vec<EquiPair>,
154 join_kind: VidJoinKind,
155 ) -> DFResult<Self> {
156 if pairs.is_empty() {
157 return Err(datafusion::error::DataFusionError::Plan(
158 "VidLookupJoinExec: pairs must be non-empty".into(),
159 ));
160 }
161 let probe_plan = match probe_side {
162 ProbeSide::Left => &left,
163 ProbeSide::Right => &right,
164 };
165 if probe_plan
166 .as_any()
167 .downcast_ref::<GraphScanExec>()
168 .is_none()
169 {
170 return Err(datafusion::error::DataFusionError::Plan(
171 "VidLookupJoinExec: probe-side child must be a GraphScanExec".into(),
172 ));
173 }
174 let output_schema = concat_schemas(&left.schema(), &right.schema());
175 let properties = compute_plan_properties(output_schema.clone());
176 Ok(Self {
177 left,
178 right,
179 probe_side,
180 pairs,
181 join_kind,
182 output_schema,
183 properties,
184 metrics: ExecutionPlanMetricsSet::new(),
185 })
186 }
187
188 fn build_child(&self) -> &Arc<dyn ExecutionPlan> {
189 match self.probe_side {
190 ProbeSide::Left => &self.right,
191 ProbeSide::Right => &self.left,
192 }
193 }
194
195 fn probe_child(&self) -> &Arc<dyn ExecutionPlan> {
196 match self.probe_side {
197 ProbeSide::Left => &self.left,
198 ProbeSide::Right => &self.right,
199 }
200 }
201}
202
203impl DisplayAs for VidLookupJoinExec {
204 fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter<'_>) -> fmt::Result {
205 write!(
206 f,
207 "VidLookupJoinExec: probe={:?}, pairs={}, kind={:?}",
208 self.probe_side,
209 self.pairs.len(),
210 self.join_kind
211 )
212 }
213}
214
215impl ExecutionPlan for VidLookupJoinExec {
216 fn name(&self) -> &str {
217 "VidLookupJoinExec"
218 }
219
220 fn as_any(&self) -> &dyn Any {
221 self
222 }
223
224 fn schema(&self) -> SchemaRef {
225 self.output_schema.clone()
226 }
227
228 fn properties(&self) -> &Arc<PlanProperties> {
229 &self.properties
230 }
231
232 fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
233 vec![self.build_child()]
238 }
239
240 fn with_new_children(
241 self: Arc<Self>,
242 children: Vec<Arc<dyn ExecutionPlan>>,
243 ) -> DFResult<Arc<dyn ExecutionPlan>> {
244 if children.len() != 1 {
245 return Err(datafusion::error::DataFusionError::Plan(format!(
246 "VidLookupJoinExec expects exactly one child (the build side); got {}",
247 children.len()
248 )));
249 }
250 let new_build = children.into_iter().next().unwrap();
251 let (new_left, new_right) = match self.probe_side {
252 ProbeSide::Left => (self.left.clone(), new_build),
253 ProbeSide::Right => (new_build, self.right.clone()),
254 };
255 Ok(Arc::new(Self::try_new(
256 new_left,
257 new_right,
258 self.probe_side,
259 self.pairs.clone(),
260 self.join_kind,
261 )?))
262 }
263
264 fn execute(
265 &self,
266 partition: usize,
267 context: Arc<TaskContext>,
268 ) -> DFResult<SendableRecordBatchStream> {
269 let metrics = BaselineMetrics::new(&self.metrics, partition);
270 let build = self.build_child().clone();
271 let probe = self.probe_child().clone();
272 let probe_side = self.probe_side;
273 let pairs = self.pairs.clone();
274 let join_kind = self.join_kind;
275 let output_schema = self.output_schema.clone();
276 let left_schema = self.left.schema();
277 let right_schema = self.right.schema();
278
279 let fut = async move {
280 run_join(
281 build,
282 probe,
283 probe_side,
284 pairs,
285 join_kind,
286 left_schema,
287 right_schema,
288 output_schema.clone(),
289 partition,
290 context,
291 )
292 .await
293 };
294
295 Ok(Box::pin(VidLookupJoinStream {
296 state: VidLookupJoinStreamState::Running(Box::pin(fut)),
297 schema: self.output_schema.clone(),
298 metrics,
299 }))
300 }
301
302 fn metrics(&self) -> Option<MetricsSet> {
303 Some(self.metrics.clone_inner())
304 }
305}
306
/// Progress of a [`VidLookupJoinStream`].
enum VidLookupJoinStreamState {
    /// The join future is still pending; it resolves to the single output
    /// batch or to an error.
    Running(Pin<Box<dyn std::future::Future<Output = DFResult<RecordBatch>> + Send>>),
    /// The future has completed and its result was handed out; every further
    /// poll reports end of stream.
    Done,
}
315
/// Record-batch stream that yields at most one item: the result of the join
/// future held in `state`.
struct VidLookupJoinStream {
    /// Pending join future, or `Done` once its result has been returned.
    state: VidLookupJoinStreamState,
    /// Schema reported through `RecordBatchStream::schema`.
    schema: SchemaRef,
    /// Baseline metrics (compute time, output rows) for this partition.
    metrics: BaselineMetrics,
}
321
322impl Stream for VidLookupJoinStream {
323 type Item = DFResult<RecordBatch>;
324
325 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
326 let metrics = self.metrics.clone();
327 let _timer = metrics.elapsed_compute().timer();
328 match &mut self.state {
329 VidLookupJoinStreamState::Running(fut) => match fut.as_mut().poll(cx) {
330 Poll::Ready(Ok(batch)) => {
331 self.metrics.record_output(batch.num_rows());
332 self.state = VidLookupJoinStreamState::Done;
333 Poll::Ready(Some(Ok(batch)))
334 }
335 Poll::Ready(Err(e)) => {
336 self.state = VidLookupJoinStreamState::Done;
337 Poll::Ready(Some(Err(e)))
338 }
339 Poll::Pending => Poll::Pending,
340 },
341 VidLookupJoinStreamState::Done => Poll::Ready(None),
342 }
343 }
344}
345
346impl RecordBatchStream for VidLookupJoinStream {
347 fn schema(&self) -> SchemaRef {
348 self.schema.clone()
349 }
350}
351
352#[allow(clippy::too_many_arguments)]
357async fn run_join(
358 build: Arc<dyn ExecutionPlan>,
359 probe: Arc<dyn ExecutionPlan>,
360 probe_side: ProbeSide,
361 pairs: Vec<EquiPair>,
362 join_kind: VidJoinKind,
363 left_schema: SchemaRef,
364 right_schema: SchemaRef,
365 output_schema: SchemaRef,
366 partition: usize,
367 context: Arc<TaskContext>,
368) -> DFResult<RecordBatch> {
369 let build_stream = build.execute(partition, context)?;
371 let build_batches: Vec<RecordBatch> = build_stream.try_collect().await?;
372
373 if build_batches.is_empty() {
374 return Ok(RecordBatch::new_empty(output_schema));
375 }
376
377 let anchor = pairs[0];
379 let build_anchor_col_idx = anchor.build_col(probe_side);
380 let mut vid_set: HashSet<u64> = HashSet::new();
381 for batch in &build_batches {
382 let arr = batch.column(build_anchor_col_idx);
383 let u64_arr = arr.as_any().downcast_ref::<UInt64Array>().ok_or_else(|| {
384 datafusion::error::DataFusionError::Plan(format!(
385 "VidLookupJoinExec: build anchor column at idx {} is not UInt64 (got {:?})",
386 build_anchor_col_idx,
387 arr.data_type()
388 ))
389 })?;
390 for i in 0..u64_arr.len() {
391 if !u64_arr.is_null(i) {
392 vid_set.insert(u64_arr.value(i));
393 }
394 }
395 }
396
397 let probe_scan = probe
402 .as_any()
403 .downcast_ref::<GraphScanExec>()
404 .expect("planner ensured probe is GraphScanExec");
405 let probe_batch = if vid_set.is_empty() {
406 RecordBatch::new_empty(probe_scan.schema())
409 } else {
410 let vids: Vec<u64> = vid_set.iter().copied().collect();
411 let mut chunks: Vec<RecordBatch> = Vec::new();
412 for chunk in vids.chunks(MAX_VIDS_PER_CHUNK) {
413 let batch = probe_scan.execute_with_vid_filter(chunk).await?;
414 if batch.num_rows() > 0 {
415 chunks.push(batch);
416 }
417 }
418 if chunks.is_empty() {
419 RecordBatch::new_empty(probe_scan.schema())
420 } else if chunks.len() == 1 {
421 chunks.into_iter().next().unwrap()
422 } else {
423 let schema = chunks[0].schema();
424 arrow::compute::concat_batches(&schema, &chunks)
425 .map_err(|e| datafusion::error::DataFusionError::ArrowError(Box::new(e), None))?
426 }
427 };
428
429 let probe_vid_idx = locate_vid_column(&probe_batch.schema())?;
434 let probe_anchor_col_idx = anchor.probe_col(probe_side);
435 if probe_anchor_col_idx != probe_vid_idx {
438 return Err(datafusion::error::DataFusionError::Plan(format!(
439 "VidLookupJoinExec: anchor probe column idx {} != probe schema's _vid idx {} \
440 (planner pre-check should have aligned these)",
441 probe_anchor_col_idx, probe_vid_idx
442 )));
443 }
444 let probe_index = build_probe_vid_index(&probe_batch, probe_vid_idx)?;
445
446 let n_non_anchor = pairs.len() - 1;
452 let mut matches: Vec<JoinMatch> = Vec::new();
453 let mut unmatched: Vec<(usize, usize)> = Vec::new(); for (build_batch_idx, build_batch) in build_batches.iter().enumerate() {
456 let build_anchor_arr = build_batch
457 .column(build_anchor_col_idx)
458 .as_any()
459 .downcast_ref::<UInt64Array>()
460 .expect("validated above");
461 for build_row_idx in 0..build_anchor_arr.len() {
462 if build_anchor_arr.is_null(build_row_idx) {
463 if join_kind == VidJoinKind::Left {
464 unmatched.push((build_batch_idx, build_row_idx));
465 }
466 continue;
467 }
468 let key = build_anchor_arr.value(build_row_idx);
469 let Some(probe_rows) = probe_index.get(&key) else {
470 if join_kind == VidJoinKind::Left {
471 unmatched.push((build_batch_idx, build_row_idx));
472 }
473 continue;
474 };
475
476 let mut had_match_for_this_build_row = false;
478 for &probe_row_idx in probe_rows {
479 let mut all_match = true;
480 for pair in &pairs[1..1 + n_non_anchor] {
481 let build_col_idx = pair.build_col(probe_side);
482 let probe_col_idx = pair.probe_col(probe_side);
483 if !values_equal(
484 build_batch.column(build_col_idx),
485 build_row_idx,
486 probe_batch.column(probe_col_idx),
487 probe_row_idx,
488 )? {
489 all_match = false;
490 break;
491 }
492 }
493 if all_match {
494 matches.push(JoinMatch {
495 build_batch_idx,
496 build_row_idx,
497 probe_row_idx,
498 });
499 had_match_for_this_build_row = true;
500 }
501 }
502 if !had_match_for_this_build_row && join_kind == VidJoinKind::Left {
503 unmatched.push((build_batch_idx, build_row_idx));
504 }
505 }
506 }
507
508 emit_joined_batch(
510 &build_batches,
511 &probe_batch,
512 &matches,
513 &unmatched,
514 probe_side,
515 &left_schema,
516 &right_schema,
517 &output_schema,
518 )
519}
520
/// One matched (build row, probe row) pair found by `run_join`.
#[derive(Clone, Copy)]
struct JoinMatch {
    /// Index of the build batch within the collected build batches.
    build_batch_idx: usize,
    /// Row index inside that build batch.
    build_row_idx: usize,
    /// Row index inside the merged probe batch.
    probe_row_idx: usize,
}
531
532fn build_probe_vid_index(
533 probe_batch: &RecordBatch,
534 probe_vid_idx: usize,
535) -> DFResult<HashMap<u64, Vec<usize>>> {
536 let arr = probe_batch
537 .column(probe_vid_idx)
538 .as_any()
539 .downcast_ref::<UInt64Array>()
540 .ok_or_else(|| {
541 datafusion::error::DataFusionError::Plan(
542 "VidLookupJoinExec: probe `_vid` column is not UInt64".into(),
543 )
544 })?;
545 let mut index: HashMap<u64, Vec<usize>> = HashMap::with_capacity(arr.len());
546 for i in 0..arr.len() {
547 if !arr.is_null(i) {
548 index.entry(arr.value(i)).or_default().push(i);
549 }
550 }
551 Ok(index)
552}
553
554fn values_equal(a_col: &ArrayRef, a_row: usize, b_col: &ArrayRef, b_row: usize) -> DFResult<bool> {
559 let a = ScalarValue::try_from_array(a_col, a_row)?;
560 let b = ScalarValue::try_from_array(b_col, b_row)?;
561 Ok(a == b)
562}
563
564fn locate_vid_column(schema: &SchemaRef) -> DFResult<usize> {
566 schema
567 .fields()
568 .iter()
569 .enumerate()
570 .find_map(|(i, f)| {
571 if f.name() == "_vid" || f.name().ends_with("._vid") {
572 Some(i)
573 } else {
574 None
575 }
576 })
577 .ok_or_else(|| {
578 datafusion::error::DataFusionError::Plan(
579 "VidLookupJoinExec: probe schema has no _vid column".into(),
580 )
581 })
582}
583
584fn concat_schemas(left: &SchemaRef, right: &SchemaRef) -> SchemaRef {
587 let mut fields: Vec<Field> = Vec::with_capacity(left.fields().len() + right.fields().len());
588 for f in left.fields() {
589 fields.push(f.as_ref().clone());
590 }
591 for f in right.fields() {
592 fields.push(f.as_ref().clone());
593 }
594 Arc::new(Schema::new(fields))
595}
596
597#[allow(clippy::too_many_arguments)]
605fn emit_joined_batch(
606 build_batches: &[RecordBatch],
607 probe_batch: &RecordBatch,
608 matches: &[JoinMatch],
609 unmatched: &[(usize, usize)],
610 probe_side: ProbeSide,
611 left_schema: &SchemaRef,
612 right_schema: &SchemaRef,
613 output_schema: &SchemaRef,
614) -> DFResult<RecordBatch> {
615 let total_rows = matches.len() + unmatched.len();
616 if total_rows == 0 {
617 return Ok(RecordBatch::new_empty(output_schema.clone()));
618 }
619
620 let n_build_batches = build_batches.len();
623 let mut match_take_per_build_batch: Vec<Vec<u32>> =
624 (0..n_build_batches).map(|_| Vec::new()).collect();
625 let mut match_probe_take: Vec<u32> = Vec::with_capacity(matches.len());
626 for m in matches {
627 match_take_per_build_batch[m.build_batch_idx].push(m.build_row_idx as u32);
628 match_probe_take.push(m.probe_row_idx as u32);
629 }
630
631 let mut unmatched_take_per_build_batch: Vec<Vec<u32>> =
633 (0..n_build_batches).map(|_| Vec::new()).collect();
634 for &(bb_idx, br_idx) in unmatched {
635 unmatched_take_per_build_batch[bb_idx].push(br_idx as u32);
636 }
637
638 let n_build_cols = build_batches[0].num_columns();
641 let mut build_columns: Vec<ArrayRef> = Vec::with_capacity(n_build_cols);
642 for col_idx in 0..n_build_cols {
643 let mut chunks: Vec<ArrayRef> = Vec::new();
644 for batch_idx in 0..n_build_batches {
645 if !match_take_per_build_batch[batch_idx].is_empty() {
647 chunks.push(take_indices(
648 build_batches[batch_idx].column(col_idx),
649 &match_take_per_build_batch[batch_idx],
650 )?);
651 }
652 if !unmatched_take_per_build_batch[batch_idx].is_empty() {
654 chunks.push(take_indices(
655 build_batches[batch_idx].column(col_idx),
656 &unmatched_take_per_build_batch[batch_idx],
657 )?);
658 }
659 }
660 build_columns.push(concat_arrays(&chunks)?);
661 }
662
663 let n_probe_cols = probe_batch.num_columns();
666 let mut probe_columns: Vec<ArrayRef> = Vec::with_capacity(n_probe_cols);
667 let probe_match_arr = take_indices_u32_slice(&match_probe_take);
668 let n_unmatched = unmatched.len();
669 for col_idx in 0..n_probe_cols {
670 let probe_col = probe_batch.column(col_idx);
671 let matched_part = if match_probe_take.is_empty() {
672 arrow_array::new_empty_array(probe_col.data_type())
673 } else {
674 arrow::compute::take(probe_col.as_ref(), &probe_match_arr, None)
675 .map_err(|e| datafusion::error::DataFusionError::ArrowError(Box::new(e), None))?
676 };
677 if n_unmatched == 0 {
678 probe_columns.push(matched_part);
679 } else {
680 let null_part = arrow_array::new_null_array(probe_col.data_type(), n_unmatched);
681 probe_columns.push(concat_arrays(&[matched_part, null_part])?);
682 }
683 }
684
685 let (left_columns, right_columns) = match probe_side {
687 ProbeSide::Left => (probe_columns, build_columns),
688 ProbeSide::Right => (build_columns, probe_columns),
689 };
690
691 let _ = (left_schema, right_schema); let mut all_columns = left_columns;
694 all_columns.extend(right_columns);
695
696 RecordBatch::try_new(output_schema.clone(), all_columns)
697 .map_err(|e| datafusion::error::DataFusionError::ArrowError(Box::new(e), None))
698}
699
700fn take_indices(col: &ArrayRef, indices: &[u32]) -> DFResult<ArrayRef> {
701 let take_array = take_indices_u32_slice(indices);
702 arrow::compute::take(col.as_ref(), &take_array, None)
703 .map_err(|e| datafusion::error::DataFusionError::ArrowError(Box::new(e), None))
704}
705
706fn take_indices_u32_slice(indices: &[u32]) -> arrow_array::UInt32Array {
707 let mut b = UInt32Builder::with_capacity(indices.len());
708 for &i in indices {
709 b.append_value(i);
710 }
711 b.finish()
712}
713
714fn concat_arrays(arrays: &[ArrayRef]) -> DFResult<ArrayRef> {
715 if arrays.len() == 1 {
716 return Ok(arrays[0].clone());
717 }
718 let refs: Vec<&dyn Array> = arrays.iter().map(|a| a.as_ref()).collect();
719 arrow::compute::concat(&refs)
720 .map_err(|e| datafusion::error::DataFusionError::ArrowError(Box::new(e), None))
721}