1use crate::query::df_graph::GraphExecutionContext;
10use crate::query::df_graph::common::{
11 ScalarKey, arrow_err, collect_all_partitions, compute_plan_properties, execute_subplan,
12 extract_scalar_key,
13};
14use crate::query::df_graph::locy_best_by::{BestByExec, SortCriterion};
15use crate::query::df_graph::locy_errors::LocyRuntimeError;
16use crate::query::df_graph::locy_explain::{
17 ProofTerm, ProvenanceAnnotation, ProvenanceStore, compute_proof_probability,
18};
19use crate::query::df_graph::locy_fold::{FoldBinding, FoldExec};
20use crate::query::df_graph::locy_priority::PriorityExec;
21use crate::query::planner::LogicalPlan;
22use arrow_array::RecordBatch;
23use arrow_row::{RowConverter, SortField};
24use arrow_schema::SchemaRef;
25use datafusion::common::JoinType;
26use datafusion::common::Result as DFResult;
27use datafusion::execution::{RecordBatchStream, SendableRecordBatchStream, TaskContext};
28use datafusion::physical_plan::joins::{HashJoinExec, PartitionMode};
29use datafusion::physical_plan::memory::MemoryStream;
30use datafusion::physical_plan::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet};
31use datafusion::physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties};
32use futures::Stream;
33use parking_lot::RwLock;
34use std::any::Any;
35use std::collections::{HashMap, HashSet};
36use std::fmt;
37use std::pin::Pin;
38use std::sync::{Arc, RwLock as StdRwLock};
39use std::task::{Context, Poll};
40use std::time::{Duration, Instant};
41use uni_common::Value;
42use uni_common::core::schema::Schema as UniSchema;
43use uni_cypher::ast::Expr;
44use uni_locy::RuntimeWarning;
45use uni_store::storage::manager::StorageManager;
46
/// One registered "derived scan": a shared, mutable buffer of record batches
/// produced by a fixpoint rule and read back by scans inside rule bodies.
#[derive(Debug)]
pub struct DerivedScanEntry {
    /// Unique index this scan is registered under (lookup key in the registry).
    pub scan_index: usize,
    /// Name of the fixpoint rule whose facts back this scan.
    pub rule_name: String,
    /// True when the scan reads the same rule it feeds (recursive reference).
    pub is_self_ref: bool,
    /// Current facts for this scan; overwritten each fixpoint iteration.
    /// `parking_lot::RwLock` so readers/writers across clauses stay cheap.
    pub data: Arc<RwLock<Vec<RecordBatch>>>,
    /// Arrow schema of the batches stored in `data`.
    pub schema: SchemaRef,
}
69
/// Flat registry of all derived scans participating in a fixpoint evaluation.
/// Lookups are linear scans over `entries` (the set is expected to be small).
#[derive(Debug, Default)]
pub struct DerivedScanRegistry {
    entries: Vec<DerivedScanEntry>,
}
80
81impl DerivedScanRegistry {
82 pub fn new() -> Self {
84 Self::default()
85 }
86
87 pub fn add(&mut self, entry: DerivedScanEntry) {
89 self.entries.push(entry);
90 }
91
92 pub fn get(&self, scan_index: usize) -> Option<&DerivedScanEntry> {
94 self.entries.iter().find(|e| e.scan_index == scan_index)
95 }
96
97 pub fn write_data(&self, scan_index: usize, batches: Vec<RecordBatch>) {
99 if let Some(entry) = self.get(scan_index) {
100 let mut guard = entry.data.write();
101 *guard = batches;
102 }
103 }
104
105 pub fn entries_for_rule(&self, rule_name: &str) -> Vec<&DerivedScanEntry> {
107 self.entries
108 .iter()
109 .filter(|e| e.rule_name == rule_name)
110 .collect()
111 }
112}
113
/// Binding of one monotonic fold (aggregate) to its input column, used by
/// `MonotonicAggState` to locate the value to accumulate per delta row.
#[derive(Debug, Clone)]
pub struct MonotonicFoldBinding {
    /// Output name of the fold; part of the accumulator map key.
    pub fold_name: String,
    /// Which aggregate to apply (Sum/Count/Min/Max/Nor/Prod, …).
    pub kind: crate::query::df_graph::locy_fold::FoldAggKind,
    /// Positional fallback used when `input_col_name` cannot be resolved.
    pub input_col_index: usize,
    /// Preferred column lookup by name (schemas may be reordered across clauses).
    pub input_col_name: Option<String>,
}
127
/// Accumulator state for monotonic aggregates inside a fixpoint rule.
/// Keys are (group key extracted from the rule's key columns, fold name).
#[derive(Debug)]
pub struct MonotonicAggState {
    // Current accumulated value per (group, fold).
    accumulators: HashMap<(Vec<ScalarKey>, String), f64>,
    // Copy of `accumulators` taken at the previous `snapshot()`; compared in
    // `is_stable()` to detect convergence.
    prev_snapshot: HashMap<(Vec<ScalarKey>, String), f64>,
    // Which folds to maintain and where their inputs come from.
    bindings: Vec<MonotonicFoldBinding>,
}
142
impl MonotonicAggState {
    /// Creates empty accumulator state for the given fold bindings.
    pub fn new(bindings: Vec<MonotonicFoldBinding>) -> Self {
        Self {
            accumulators: HashMap::new(),
            prev_snapshot: HashMap::new(),
            bindings,
        }
    }

    /// Folds every row of `delta_batches` into the accumulators.
    ///
    /// Returns `Ok(true)` when at least one accumulator moved by more than
    /// `f64::EPSILON`. With `strict` set, MNOR/MPROD inputs outside [0, 1]
    /// are a hard error; otherwise they are clamped with a warning.
    pub fn update(
        &mut self,
        key_indices: &[usize],
        delta_batches: &[RecordBatch],
        strict: bool,
    ) -> DFResult<bool> {
        use crate::query::df_graph::locy_fold::FoldAggKind;

        let mut changed = false;
        for batch in delta_batches {
            for row_idx in 0..batch.num_rows() {
                let group_key = extract_scalar_key(batch, key_indices, row_idx);
                for binding in &self.bindings {
                    // Resolve the input column by name first (schemas can be
                    // reordered across clauses), falling back to the position.
                    let idx = binding
                        .input_col_name
                        .as_ref()
                        .and_then(|name| batch.schema().index_of(name).ok())
                        .unwrap_or(binding.input_col_index);
                    if idx >= batch.num_columns() {
                        continue;
                    }
                    let col = batch.column(idx);
                    // Non-numeric / null cells are skipped entirely.
                    let val = extract_f64(col.as_ref(), row_idx);
                    if let Some(val) = val {
                        let map_key = (group_key.clone(), binding.fold_name.clone());
                        let entry = self
                            .accumulators
                            .entry(map_key)
                            .or_insert(binding.kind.identity().unwrap_or(0.0));
                        let old = *entry;
                        match binding.kind {
                            // NOTE(review): Count shares the Sum path and adds
                            // the column value — presumably the input column
                            // carries per-row counts; confirm upstream.
                            FoldAggKind::Sum | FoldAggKind::Count => *entry += val,
                            FoldAggKind::Max => {
                                if val > *entry {
                                    *entry = val;
                                }
                            }
                            FoldAggKind::Min => {
                                if val < *entry {
                                    *entry = val;
                                }
                            }
                            // Noisy-OR: acc = 1 - (1 - acc) * (1 - p).
                            FoldAggKind::Nor => {
                                if strict && !(0.0..=1.0).contains(&val) {
                                    return Err(datafusion::error::DataFusionError::Execution(
                                        format!(
                                            "strict_probability_domain: MNOR input {val} is outside [0, 1]"
                                        ),
                                    ));
                                }
                                if !strict && !(0.0..=1.0).contains(&val) {
                                    tracing::warn!(
                                        "MNOR input {val} outside [0,1], clamped to {}",
                                        val.clamp(0.0, 1.0)
                                    );
                                }
                                let p = val.clamp(0.0, 1.0);
                                *entry = 1.0 - (1.0 - *entry) * (1.0 - p);
                            }
                            FoldAggKind::Prod => {
                                if strict && !(0.0..=1.0).contains(&val) {
                                    return Err(datafusion::error::DataFusionError::Execution(
                                        format!(
                                            "strict_probability_domain: MPROD input {val} is outside [0, 1]"
                                        ),
                                    ));
                                }
                                if !strict && !(0.0..=1.0).contains(&val) {
                                    tracing::warn!(
                                        "MPROD input {val} outside [0,1], clamped to {}",
                                        val.clamp(0.0, 1.0)
                                    );
                                }
                                let p = val.clamp(0.0, 1.0);
                                *entry *= p;
                            }
                            _ => {}
                        }
                        if (*entry - old).abs() > f64::EPSILON {
                            changed = true;
                        }
                    }
                }
            }
        }
        Ok(changed)
    }

    /// Records the current accumulator values as the comparison baseline for
    /// the next `is_stable()` check.
    pub fn snapshot(&mut self) {
        self.prev_snapshot = self.accumulators.clone();
    }

    /// True when no accumulator appeared, disappeared, or moved by more than
    /// `f64::EPSILON` since the last `snapshot()`.
    pub fn is_stable(&self) -> bool {
        if self.accumulators.len() != self.prev_snapshot.len() {
            return false;
        }
        for (key, val) in &self.accumulators {
            match self.prev_snapshot.get(key) {
                Some(prev) if (*val - *prev).abs() <= f64::EPSILON => {}
                _ => return false,
            }
        }
        true
    }

    /// Test-only peek at a single accumulator value.
    #[cfg(test)]
    pub(crate) fn get_accumulator(&self, key: &(Vec<ScalarKey>, String)) -> Option<f64> {
        self.accumulators.get(key).copied()
    }
}
272
273fn extract_f64(col: &dyn arrow_array::Array, row_idx: usize) -> Option<f64> {
275 if col.is_null(row_idx) {
276 return None;
277 }
278 if let Some(arr) = col.as_any().downcast_ref::<arrow_array::Float64Array>() {
279 Some(arr.value(row_idx))
280 } else {
281 col.as_any()
282 .downcast_ref::<arrow_array::Int64Array>()
283 .map(|arr| arr.value(row_idx) as f64)
284 }
285}
286
/// Row-level dedup based on `arrow_row`'s byte-comparable row encoding:
/// each row is converted to its row-format bytes and remembered in `seen`.
struct RowDedupState {
    // Converts Arrow columns into the comparable row encoding.
    converter: RowConverter,
    // Encoded rows observed so far; membership test == "row already derived".
    seen: HashSet<Box<[u8]>>,
}
300
301impl RowDedupState {
302 fn try_new(schema: &SchemaRef) -> Option<Self> {
307 let fields: Vec<SortField> = schema
308 .fields()
309 .iter()
310 .map(|f| SortField::new(f.data_type().clone()))
311 .collect();
312 match RowConverter::new(fields) {
313 Ok(converter) => Some(Self {
314 converter,
315 seen: HashSet::new(),
316 }),
317 Err(e) => {
318 tracing::warn!(
319 "RowDedupState: RowConverter unsupported for schema, falling back to legacy dedup: {}",
320 e
321 );
322 None
323 }
324 }
325 }
326
327 fn ingest_existing(&mut self, facts: &[RecordBatch], _schema: &SchemaRef) {
332 self.seen.clear();
333 for batch in facts {
334 if batch.num_rows() == 0 {
335 continue;
336 }
337 let arrays: Vec<_> = batch.columns().to_vec();
338 if let Ok(rows) = self.converter.convert_columns(&arrays) {
339 for row_idx in 0..batch.num_rows() {
340 let row_bytes: Box<[u8]> = rows.row(row_idx).data().into();
341 self.seen.insert(row_bytes);
342 }
343 }
344 }
345 }
346
347 fn compute_delta(
353 &mut self,
354 candidates: &[RecordBatch],
355 schema: &SchemaRef,
356 ) -> DFResult<Vec<RecordBatch>> {
357 let mut delta_batches = Vec::new();
358 for batch in candidates {
359 if batch.num_rows() == 0 {
360 continue;
361 }
362
363 let arrays: Vec<_> = batch.columns().to_vec();
365 let rows = self.converter.convert_columns(&arrays).map_err(arrow_err)?;
366
367 let mut keep = Vec::with_capacity(batch.num_rows());
369 for row_idx in 0..batch.num_rows() {
370 let row_bytes: Box<[u8]> = rows.row(row_idx).data().into();
371 keep.push(self.seen.insert(row_bytes));
372 }
373
374 let keep_mask = arrow_array::BooleanArray::from(keep);
375 let new_cols = batch
376 .columns()
377 .iter()
378 .map(|col| {
379 arrow::compute::filter(col.as_ref(), &keep_mask).map_err(|e| {
380 datafusion::error::DataFusionError::ArrowError(Box::new(e), None)
381 })
382 })
383 .collect::<DFResult<Vec<_>>>()?;
384
385 if new_cols.first().is_some_and(|c| !c.is_empty()) {
386 let filtered = RecordBatch::try_new(Arc::clone(schema), new_cols).map_err(|e| {
387 datafusion::error::DataFusionError::ArrowError(Box::new(e), None)
388 })?;
389 delta_batches.push(filtered);
390 }
391 }
392 Ok(delta_batches)
393 }
394}
395
/// Per-rule evaluation state for the fixpoint loop: accumulated facts, the
/// delta produced by the latest iteration, dedup machinery, and limits.
pub struct FixpointState {
    // Rule this state belongs to (used in diagnostics and errors).
    rule_name: String,
    // All facts derived so far for this rule.
    facts: Vec<RecordBatch>,
    // Facts newly derived in the most recent merge; empty when converged.
    delta: Vec<RecordBatch>,
    // Current output schema; may be reconciled against actual plan output.
    schema: SchemaRef,
    // Positions of the rule's key columns in `schema`.
    key_column_indices: Vec<usize>,
    // Names matching `key_column_indices`; used to re-resolve positions after
    // a schema reconciliation.
    key_column_names: Vec<String>,
    // 0..num_cols; convenient "whole row" index set for scalar-key extraction.
    all_column_indices: Vec<usize>,
    // Running byte size of `facts` (per `batch_byte_size`).
    facts_bytes: usize,
    // Hard cap on derived bytes; exceeding it aborts the rule.
    max_derived_bytes: usize,
    // Monotonic-aggregate accumulators, when the rule has folds.
    monotonic_agg: Option<MonotonicAggState>,
    // Fast row-encoding dedup; `None` when the schema is unsupported.
    row_dedup: Option<RowDedupState>,
    // Whether MNOR/MPROD inputs outside [0, 1] are hard errors.
    strict_probability_domain: bool,
}
426
impl FixpointState {
    /// Creates the initial (empty) state for one rule.
    ///
    /// Key column names are captured up front so positions can be re-resolved
    /// if the physical plan later produces a reordered schema.
    pub fn new(
        rule_name: String,
        schema: SchemaRef,
        key_column_indices: Vec<usize>,
        max_derived_bytes: usize,
        monotonic_agg: Option<MonotonicAggState>,
        strict_probability_domain: bool,
    ) -> Self {
        let num_cols = schema.fields().len();
        let row_dedup = RowDedupState::try_new(&schema);
        let key_column_names: Vec<String> = key_column_indices
            .iter()
            .filter_map(|&i| schema.fields().get(i).map(|f| f.name().clone()))
            .collect();
        Self {
            rule_name,
            facts: Vec::new(),
            delta: Vec::new(),
            schema,
            key_column_indices,
            key_column_names,
            all_column_indices: (0..num_cols).collect(),
            facts_bytes: 0,
            max_derived_bytes,
            monotonic_agg,
            row_dedup,
            strict_probability_domain,
        }
    }

    /// Adopts `actual_schema` when the physical plan's output differs from the
    /// planned schema: rebuilds the row-dedup state and re-resolves key column
    /// positions by name (kept unchanged if any name fails to resolve).
    fn reconcile_schema(&mut self, actual_schema: &SchemaRef) {
        if self.schema.fields() != actual_schema.fields() {
            tracing::debug!(
                rule = %self.rule_name,
                "Reconciling fixpoint schema from physical plan output",
            );
            self.schema = Arc::clone(actual_schema);
            self.row_dedup = RowDedupState::try_new(&self.schema);
            let new_indices: Vec<usize> = self
                .key_column_names
                .iter()
                .filter_map(|name| actual_schema.index_of(name).ok())
                .collect();
            if new_indices.len() == self.key_column_names.len() {
                self.key_column_indices = new_indices;
            }
            let num_cols = actual_schema.fields().len();
            self.all_column_indices = (0..num_cols).collect();
        }
    }

    /// Dedups `candidates` against accumulated facts and merges the new rows.
    ///
    /// Returns `Ok(true)` when any genuinely new row was added. Floats are
    /// rounded (12 decimal digits) before dedup so tiny numeric noise does not
    /// keep the fixpoint alive. Errors when the byte cap would be exceeded.
    pub async fn merge_delta(
        &mut self,
        candidates: Vec<RecordBatch>,
        task_ctx: Option<Arc<TaskContext>>,
    ) -> DFResult<bool> {
        if candidates.is_empty() || candidates.iter().all(|b| b.num_rows() == 0) {
            self.delta.clear();
            return Ok(false);
        }

        if let Some(first) = candidates.iter().find(|b| b.num_rows() > 0) {
            self.reconcile_schema(&first.schema());
        }

        let candidates = round_float_columns(&candidates);

        let delta = self.compute_delta(&candidates, task_ctx.as_ref()).await?;

        if delta.is_empty() || delta.iter().all(|b| b.num_rows() == 0) {
            self.delta.clear();
            // No new rows: still snapshot so aggregate stability can be
            // detected by `is_converged`.
            if let Some(ref mut agg) = self.monotonic_agg {
                agg.snapshot();
            }
            return Ok(false);
        }

        let delta_bytes: usize = delta.iter().map(batch_byte_size).sum();
        if self.facts_bytes + delta_bytes > self.max_derived_bytes {
            return Err(datafusion::error::DataFusionError::Execution(
                LocyRuntimeError::MemoryLimitExceeded {
                    rule: self.rule_name.clone(),
                    bytes: self.facts_bytes + delta_bytes,
                    limit: self.max_derived_bytes,
                }
                .to_string(),
            ));
        }

        // Snapshot BEFORE updating so `is_stable` compares this iteration
        // against the previous one.
        if let Some(ref mut agg) = self.monotonic_agg {
            agg.snapshot();
            agg.update(
                &self.key_column_indices,
                &delta,
                self.strict_probability_domain,
            )?;
        }

        self.facts_bytes += delta_bytes;
        self.facts.extend(delta.iter().cloned());
        self.delta = delta;

        Ok(true)
    }

    /// Chooses the dedup strategy: once the accumulated fact count crosses
    /// `DEDUP_ANTI_JOIN_THRESHOLD` (and a task context is available) use an
    /// Arrow LEFT ANTI join; otherwise the row-encoding dedup, falling back
    /// to scalar-key dedup for unsupported schemas.
    async fn compute_delta(
        &mut self,
        candidates: &[RecordBatch],
        task_ctx: Option<&Arc<TaskContext>>,
    ) -> DFResult<Vec<RecordBatch>> {
        let total_existing: usize = self.facts.iter().map(|b| b.num_rows()).sum();
        if total_existing >= DEDUP_ANTI_JOIN_THRESHOLD
            && let Some(ctx) = task_ctx
        {
            return arrow_left_anti_dedup(candidates.to_vec(), &self.facts, &self.schema, ctx)
                .await;
        }
        if let Some(ref mut rd) = self.row_dedup {
            rd.compute_delta(candidates, &self.schema)
        } else {
            self.compute_delta_legacy(candidates)
        }
    }

    /// Scalar-key dedup fallback: hashes every full row of the existing facts,
    /// then filters candidates. Rebuilds the `existing` set on every call, so
    /// it is quadratic over iterations — used only when `RowDedupState` is
    /// unavailable for the schema.
    fn compute_delta_legacy(&self, candidates: &[RecordBatch]) -> DFResult<Vec<RecordBatch>> {
        let mut existing: HashSet<Vec<ScalarKey>> = HashSet::new();
        for batch in &self.facts {
            for row_idx in 0..batch.num_rows() {
                let key = extract_scalar_key(batch, &self.all_column_indices, row_idx);
                existing.insert(key);
            }
        }

        let mut delta_batches = Vec::new();
        for batch in candidates {
            if batch.num_rows() == 0 {
                continue;
            }
            // First pass: drop rows already in the accumulated facts.
            let mut keep = Vec::with_capacity(batch.num_rows());
            for row_idx in 0..batch.num_rows() {
                let key = extract_scalar_key(batch, &self.all_column_indices, row_idx);
                keep.push(!existing.contains(&key));
            }

            // Second pass: also dedup repeats within/between candidate batches.
            for (row_idx, kept) in keep.iter_mut().enumerate() {
                if *kept {
                    let key = extract_scalar_key(batch, &self.all_column_indices, row_idx);
                    if !existing.insert(key) {
                        *kept = false;
                    }
                }
            }

            let keep_mask = arrow_array::BooleanArray::from(keep);
            let new_rows = batch
                .columns()
                .iter()
                .map(|col| {
                    arrow::compute::filter(col.as_ref(), &keep_mask).map_err(|e| {
                        datafusion::error::DataFusionError::ArrowError(Box::new(e), None)
                    })
                })
                .collect::<DFResult<Vec<_>>>()?;

            if new_rows.first().is_some_and(|c| !c.is_empty()) {
                let filtered =
                    RecordBatch::try_new(Arc::clone(&self.schema), new_rows).map_err(|e| {
                        datafusion::error::DataFusionError::ArrowError(Box::new(e), None)
                    })?;
                delta_batches.push(filtered);
            }
        }

        Ok(delta_batches)
    }

    /// Converged when the last merge produced no rows AND any monotonic
    /// aggregates are stable between snapshots.
    pub fn is_converged(&self) -> bool {
        let delta_empty = self.delta.is_empty() || self.delta.iter().all(|b| b.num_rows() == 0);
        let agg_stable = self.monotonic_agg.as_ref().is_none_or(|a| a.is_stable());
        delta_empty && agg_stable
    }

    /// All facts derived so far.
    pub fn all_facts(&self) -> &[RecordBatch] {
        &self.facts
    }

    /// Facts added by the most recent merge only.
    pub fn all_delta(&self) -> &[RecordBatch] {
        &self.delta
    }

    /// Consumes the state, yielding the accumulated facts.
    pub fn into_facts(self) -> Vec<RecordBatch> {
        self.facts
    }

    /// BEST BY merge: combines existing facts with `candidates`, sorts by the
    /// key columns then the sort criteria, and keeps only the first (best) row
    /// per key group. Facts are fully replaced by the pruned result.
    ///
    /// Returns `Ok(true)` when the per-key winning criteria changed; the delta
    /// is then set to the entire pruned fact set (self-referencing clauses may
    /// need any winner, not just new ones).
    pub fn merge_best_by(
        &mut self,
        candidates: Vec<RecordBatch>,
        sort_criteria: &[SortCriterion],
    ) -> DFResult<bool> {
        if candidates.is_empty() || candidates.iter().all(|b| b.num_rows() == 0) {
            self.delta.clear();
            return Ok(false);
        }

        if let Some(first) = candidates.iter().find(|b| b.num_rows() > 0) {
            self.reconcile_schema(&first.schema());
        }

        let candidates = round_float_columns(&candidates);

        // Snapshot of the current winners, taken before merging, so change
        // detection compares old vs. new best rows per key.
        let old_best: HashMap<Vec<ScalarKey>, Vec<ScalarKey>> =
            self.build_key_criteria_map(sort_criteria);

        let mut all_batches = self.facts.clone();
        all_batches.extend(candidates);
        let all_batches: Vec<_> = all_batches
            .into_iter()
            .filter(|b| b.num_rows() > 0)
            .collect();
        if all_batches.is_empty() {
            self.delta.clear();
            return Ok(false);
        }

        let combined = arrow::compute::concat_batches(&self.schema, &all_batches)
            .map_err(|e| datafusion::error::DataFusionError::ArrowError(Box::new(e), None))?;

        if combined.num_rows() == 0 {
            self.delta.clear();
            return Ok(false);
        }

        // Sort by key columns first (grouping), then by the user criteria so
        // the best row of each group comes first.
        let mut sort_columns = Vec::new();
        for &ki in &self.key_column_indices {
            if ki >= combined.num_columns() {
                continue;
            }
            sort_columns.push(arrow::compute::SortColumn {
                values: Arc::clone(combined.column(ki)),
                options: Some(arrow::compute::SortOptions {
                    descending: false,
                    nulls_first: false,
                }),
            });
        }
        for criterion in sort_criteria {
            if criterion.col_index >= combined.num_columns() {
                continue;
            }
            sort_columns.push(arrow::compute::SortColumn {
                values: Arc::clone(combined.column(criterion.col_index)),
                options: Some(arrow::compute::SortOptions {
                    descending: !criterion.ascending,
                    nulls_first: criterion.nulls_first,
                }),
            });
        }

        let sorted_indices =
            arrow::compute::lexsort_to_indices(&sort_columns, None).map_err(arrow_err)?;
        let sorted_columns: Vec<_> = combined
            .columns()
            .iter()
            .map(|col| arrow::compute::take(col.as_ref(), &sorted_indices, None))
            .collect::<Result<Vec<_>, _>>()
            .map_err(|e| datafusion::error::DataFusionError::ArrowError(Box::new(e), None))?;
        let sorted = RecordBatch::try_new(Arc::clone(&self.schema), sorted_columns)
            .map_err(|e| datafusion::error::DataFusionError::ArrowError(Box::new(e), None))?;

        // Keep the first row of each key group (the best one after sorting).
        let mut keep_indices: Vec<u32> = Vec::new();
        let mut prev_key: Option<Vec<ScalarKey>> = None;
        for row_idx in 0..sorted.num_rows() {
            let key = extract_scalar_key(&sorted, &self.key_column_indices, row_idx);
            let is_new_group = match &prev_key {
                None => true,
                Some(prev) => *prev != key,
            };
            if is_new_group {
                keep_indices.push(row_idx as u32);
                prev_key = Some(key);
            }
        }

        let keep_array = arrow_array::UInt32Array::from(keep_indices);
        let output_columns: Vec<_> = sorted
            .columns()
            .iter()
            .map(|col| arrow::compute::take(col.as_ref(), &keep_array, None))
            .collect::<Result<Vec<_>, _>>()
            .map_err(|e| datafusion::error::DataFusionError::ArrowError(Box::new(e), None))?;
        let pruned = RecordBatch::try_new(Arc::clone(&self.schema), output_columns)
            .map_err(|e| datafusion::error::DataFusionError::ArrowError(Box::new(e), None))?;

        let new_best: HashMap<Vec<ScalarKey>, Vec<ScalarKey>> = {
            let mut map = HashMap::new();
            for row_idx in 0..pruned.num_rows() {
                let key = extract_scalar_key(&pruned, &self.key_column_indices, row_idx);
                let criteria: Vec<ScalarKey> = sort_criteria
                    .iter()
                    .flat_map(|c| extract_scalar_key(&pruned, &[c.col_index], row_idx))
                    .collect();
                map.insert(key, criteria);
            }
            map
        };
        let changed = old_best != new_best;

        tracing::debug!(
            rule = %self.rule_name,
            old_keys = old_best.len(),
            new_keys = new_best.len(),
            changed = changed,
            "BEST BY merge"
        );

        // Facts are fully replaced by the pruned winners (not appended).
        self.facts_bytes = batch_byte_size(&pruned);
        self.facts = vec![pruned];
        if changed {
            self.delta = self.facts.clone();
        } else {
            self.delta.clear();
        }

        // Rebuild dedup state so it reflects the replaced fact set.
        self.row_dedup = RowDedupState::try_new(&self.schema);
        if let Some(ref mut rd) = self.row_dedup {
            rd.ingest_existing(&self.facts, &self.schema);
        }

        Ok(changed)
    }

    /// Maps each key group in the current facts to its criteria values,
    /// used for BEST BY change detection.
    fn build_key_criteria_map(
        &self,
        sort_criteria: &[SortCriterion],
    ) -> HashMap<Vec<ScalarKey>, Vec<ScalarKey>> {
        let mut map = HashMap::new();
        for batch in &self.facts {
            for row_idx in 0..batch.num_rows() {
                let key = extract_scalar_key(batch, &self.key_column_indices, row_idx);
                let criteria: Vec<ScalarKey> = sort_criteria
                    .iter()
                    .flat_map(|c| extract_scalar_key(batch, &[c.col_index], row_idx))
                    .collect();
                map.insert(key, criteria);
            }
        }
        map
    }
}
839
840fn batch_byte_size(batch: &RecordBatch) -> usize {
842 batch
843 .columns()
844 .iter()
845 .map(|col| col.get_buffer_memory_size())
846 .sum()
847}
848
849fn round_float_columns(batches: &[RecordBatch]) -> Vec<RecordBatch> {
855 batches
856 .iter()
857 .map(|batch| {
858 let schema = batch.schema();
859 let has_float = schema
860 .fields()
861 .iter()
862 .any(|f| *f.data_type() == arrow_schema::DataType::Float64);
863 if !has_float {
864 return batch.clone();
865 }
866
867 let columns: Vec<arrow_array::ArrayRef> = batch
868 .columns()
869 .iter()
870 .enumerate()
871 .map(|(i, col)| {
872 if *schema.field(i).data_type() == arrow_schema::DataType::Float64 {
873 let arr = col
874 .as_any()
875 .downcast_ref::<arrow_array::Float64Array>()
876 .unwrap();
877 let rounded: arrow_array::Float64Array = arr
878 .iter()
879 .map(|v| v.map(|f| (f * 1e12).round() / 1e12))
880 .collect();
881 Arc::new(rounded) as arrow_array::ArrayRef
882 } else {
883 Arc::clone(col)
884 }
885 })
886 .collect();
887
888 RecordBatch::try_new(schema, columns).unwrap_or_else(|_| batch.clone())
889 })
890 .collect()
891}
892
/// Accumulated-row threshold above which `FixpointState::compute_delta`
/// switches from in-memory row-hash dedup to an Arrow LEFT ANTI hash join.
const DEDUP_ANTI_JOIN_THRESHOLD: usize = 300;
903
/// Dedups `candidates` against `existing` with a DataFusion LEFT ANTI hash
/// join over ALL columns: the join output is exactly the candidate rows that
/// have no full-row match in `existing`.
///
/// NULLs are treated as equal (`NullEqualsNull`) so null-bearing rows still
/// dedup. A zero-column schema yields no join keys and returns an empty
/// delta. Empty `existing` short-circuits to the candidates unchanged.
async fn arrow_left_anti_dedup(
    candidates: Vec<RecordBatch>,
    existing: &[RecordBatch],
    schema: &SchemaRef,
    task_ctx: &Arc<TaskContext>,
) -> DFResult<Vec<RecordBatch>> {
    if existing.is_empty() || existing.iter().all(|b| b.num_rows() == 0) {
        return Ok(candidates);
    }

    let left: Arc<dyn ExecutionPlan> = Arc::new(InMemoryExec::new(candidates, Arc::clone(schema)));
    let right: Arc<dyn ExecutionPlan> =
        Arc::new(InMemoryExec::new(existing.to_vec(), Arc::clone(schema)));

    // Equi-join key pair for every column — full-row equality.
    let on: Vec<(
        Arc<dyn datafusion::physical_plan::PhysicalExpr>,
        Arc<dyn datafusion::physical_plan::PhysicalExpr>,
    )> = schema
        .fields()
        .iter()
        .enumerate()
        .map(|(i, field)| {
            let l: Arc<dyn datafusion::physical_plan::PhysicalExpr> = Arc::new(
                datafusion::physical_plan::expressions::Column::new(field.name(), i),
            );
            let r: Arc<dyn datafusion::physical_plan::PhysicalExpr> = Arc::new(
                datafusion::physical_plan::expressions::Column::new(field.name(), i),
            );
            (l, r)
        })
        .collect();

    if on.is_empty() {
        return Ok(vec![]);
    }

    let join = HashJoinExec::try_new(
        left,
        right,
        on,
        None,
        &JoinType::LeftAnti,
        None,
        // Existing facts (build side) are collected once and probed against.
        PartitionMode::CollectLeft,
        datafusion::common::NullEquality::NullEqualsNull,
    )?;

    let join_arc: Arc<dyn ExecutionPlan> = Arc::new(join);
    collect_all_partitions(&join_arc, task_ctx.clone()).await
}
958
/// How one clause body references another (or the same) derived rule,
/// including negation and the column pairs used for anti-joins/provenance.
#[derive(Debug, Clone)]
pub struct IsRefBinding {
    /// Registry index of the referenced derived scan.
    pub derived_scan_index: usize,
    /// Name of the referenced rule.
    pub rule_name: String,
    /// True when the clause references its own rule (recursion).
    pub is_self_ref: bool,
    /// True for negated references (handled via anti-join / complement).
    pub negated: bool,
    /// (body column, derived column) pairs used to anti-join negated refs.
    pub anti_join_cols: Vec<(String, String)>,
    /// Whether the referenced rule carries a probability column.
    pub target_has_prob: bool,
    /// Name of that probability column, when present.
    pub target_prob_col: Option<String>,
    /// (body column, derived column) pairs used to trace provenance of
    /// positive references.
    pub provenance_join_cols: Vec<(String, String)>,
}
990
/// Compiled plan for one clause of a fixpoint rule.
#[derive(Debug)]
pub struct FixpointClausePlan {
    /// Logical plan for the clause body, executed each iteration.
    pub body_logical: LogicalPlan,
    /// References from this clause to derived rules.
    pub is_ref_bindings: Vec<IsRefBinding>,
    /// Optional clause-level priority.
    pub priority: Option<i64>,
    /// Column names bound via ALONG, recorded in provenance annotations.
    pub along_bindings: Vec<String>,
}
1003
/// Compiled plan for one fixpoint rule: its clauses plus the post-processing
/// configuration (folds, HAVING, BEST BY, priority, probability column).
#[derive(Debug)]
pub struct FixpointRulePlan {
    /// Rule name (diagnostics, provenance, iteration counters).
    pub name: String,
    /// The rule's clause bodies, each executed per iteration.
    pub clauses: Vec<FixpointClausePlan>,
    /// Declared output schema of the rule's yielded facts.
    pub yield_schema: SchemaRef,
    /// Positions of key columns in `yield_schema`.
    pub key_column_indices: Vec<usize>,
    /// Optional rule-level priority.
    pub priority: Option<i64>,
    /// Whether the rule uses fold aggregates.
    pub has_fold: bool,
    /// Fold definitions (drive `MonotonicAggState`).
    pub fold_bindings: Vec<FoldBinding>,
    /// HAVING predicates applied after aggregation.
    pub having: Vec<Expr>,
    /// Whether the rule uses BEST BY selection.
    pub has_best_by: bool,
    /// Sort criteria that pick the best row per key group.
    pub best_by_criteria: Vec<SortCriterion>,
    /// Whether a PRIORITY stage applies.
    pub has_priority: bool,
    /// Whether evaluation must be deterministic.
    pub deterministic: bool,
    /// Name of the probability column, when the rule is probabilistic.
    pub prob_column_name: Option<String>,
}
1036
/// Drives semi-naive fixpoint evaluation over `rules` until convergence,
/// `max_iterations`, or `timeout` — whichever comes first.
///
/// Each iteration executes every clause of every rule against the current
/// derived facts (published via `registry`), applies negation handling and
/// probability complements, merges the results into per-rule `FixpointState`s
/// (delta-dedup or BEST BY), and records provenance when a tracker is set.
/// After the loop, per-rule facts go through optional exact WMC and the
/// post-fixpoint chain (folds/HAVING/priority) before being concatenated.
///
/// `timeout_flag` is raised for BOTH timeout and non-convergence, signalling
/// partial results to the caller. Per-rule iteration counts are written to
/// `iteration_counts`.
#[expect(clippy::too_many_arguments, reason = "Fixpoint loop needs all context")]
async fn run_fixpoint_loop(
    rules: Vec<FixpointRulePlan>,
    max_iterations: usize,
    timeout: Duration,
    graph_ctx: Arc<GraphExecutionContext>,
    session_ctx: Arc<RwLock<datafusion::prelude::SessionContext>>,
    storage: Arc<StorageManager>,
    schema_info: Arc<UniSchema>,
    params: HashMap<String, Value>,
    registry: Arc<DerivedScanRegistry>,
    output_schema: SchemaRef,
    max_derived_bytes: usize,
    derivation_tracker: Option<Arc<ProvenanceStore>>,
    iteration_counts: Arc<StdRwLock<HashMap<String, usize>>>,
    strict_probability_domain: bool,
    probability_epsilon: f64,
    exact_probability: bool,
    max_bdd_variables: usize,
    warnings_slot: Arc<StdRwLock<Vec<RuntimeWarning>>>,
    approximate_slot: Arc<StdRwLock<HashMap<String, Vec<String>>>>,
    top_k_proofs: usize,
    timeout_flag: Arc<std::sync::atomic::AtomicBool>,
) -> DFResult<Vec<RecordBatch>> {
    let start = Instant::now();
    let task_ctx = session_ctx.read().task_ctx();

    // One FixpointState per rule; rules with folds also get aggregate state.
    let mut states: Vec<FixpointState> = rules
        .iter()
        .map(|rule| {
            let monotonic_agg = if !rule.fold_bindings.is_empty() {
                let bindings: Vec<MonotonicFoldBinding> = rule
                    .fold_bindings
                    .iter()
                    .map(|fb| MonotonicFoldBinding {
                        fold_name: fb.output_name.clone(),
                        kind: fb.kind.clone(),
                        input_col_index: fb.input_col_index,
                        input_col_name: fb.input_col_name.clone(),
                    })
                    .collect();
                Some(MonotonicAggState::new(bindings))
            } else {
                None
            };
            FixpointState::new(
                rule.name.clone(),
                Arc::clone(&rule.yield_schema),
                rule.key_column_indices.clone(),
                max_derived_bytes,
                monotonic_agg,
                strict_probability_domain,
            )
        })
        .collect();

    let mut converged = false;
    let mut total_iters = 0usize;
    for iteration in 0..max_iterations {
        total_iters = iteration + 1;
        tracing::debug!("fixpoint iteration {}", iteration);
        let mut any_changed = false;

        for rule_idx in 0..rules.len() {
            let rule = &rules[rule_idx];

            // Publish the latest facts so this rule's body scans see them.
            update_derived_scan_handles(&registry, &states, rule_idx, &rules);

            let mut all_candidates = Vec::new();
            let mut clause_candidates: Vec<Vec<RecordBatch>> = Vec::new();
            for clause in &rule.clauses {
                let mut batches = execute_subplan(
                    &clause.body_logical,
                    &params,
                    &HashMap::new(),
                    &graph_ctx,
                    &session_ctx,
                    &storage,
                    &schema_info,
                )
                .await?;
                // Negated references: either multiply in a probability
                // complement (probabilistic target) or anti-join the rows out.
                for binding in &clause.is_ref_bindings {
                    if binding.negated
                        && !binding.anti_join_cols.is_empty()
                        && let Some(entry) = registry.get(binding.derived_scan_index)
                    {
                        let neg_facts = entry.data.read().clone();
                        if !neg_facts.is_empty() {
                            if binding.target_has_prob && rule.prob_column_name.is_some() {
                                let complement_col =
                                    format!("__prob_complement_{}", binding.rule_name);
                                if let Some(prob_col) = &binding.target_prob_col {
                                    batches = apply_prob_complement_composite(
                                        batches,
                                        &neg_facts,
                                        &binding.anti_join_cols,
                                        prob_col,
                                        &complement_col,
                                    )?;
                                } else {
                                    batches = apply_anti_join_composite(
                                        batches,
                                        &neg_facts,
                                        &binding.anti_join_cols,
                                    )?;
                                }
                            } else {
                                batches = apply_anti_join_composite(
                                    batches,
                                    &neg_facts,
                                    &binding.anti_join_cols,
                                )?;
                            }
                        }
                    }
                }
                // Fold any accumulated complement factors into the rule's
                // probability column and drop the temporary columns.
                let complement_cols: Vec<String> = if !batches.is_empty() {
                    batches[0]
                        .schema()
                        .fields()
                        .iter()
                        .filter(|f| f.name().starts_with("__prob_complement_"))
                        .map(|f| f.name().clone())
                        .collect()
                } else {
                    vec![]
                };
                if !complement_cols.is_empty() {
                    batches = multiply_prob_factors(
                        batches,
                        rule.prob_column_name.as_deref(),
                        &complement_cols,
                    )?;
                }

                clause_candidates.push(batches.clone());
                all_candidates.extend(batches);
            }

            let changed = if rule.has_best_by && !rule.best_by_criteria.is_empty() {
                states[rule_idx].merge_best_by(all_candidates, &rule.best_by_criteria)?
            } else {
                states[rule_idx]
                    .merge_delta(all_candidates, Some(Arc::clone(&task_ctx)))
                    .await?
            };
            if changed {
                any_changed = true;
                // Only record provenance for iterations that produced rows.
                if let Some(ref tracker) = derivation_tracker {
                    record_provenance(
                        tracker,
                        rule,
                        &states[rule_idx],
                        &clause_candidates,
                        iteration,
                        &registry,
                        top_k_proofs,
                    );
                }
            }
        }

        if !any_changed && states.iter().all(|s| s.is_converged()) {
            tracing::debug!("fixpoint converged after {} iterations", iteration + 1);
            converged = true;
            break;
        }

        // Timeout is checked AFTER a full iteration, so each iteration
        // completes atomically before partial results are returned.
        if start.elapsed() > timeout {
            tracing::warn!(
                "fixpoint timeout after {} iterations; returning partial results",
                iteration + 1,
            );
            timeout_flag.store(true, std::sync::atomic::Ordering::Relaxed);
            break;
        }
    }

    if let Ok(mut counts) = iteration_counts.write() {
        for rule in &rules {
            counts.insert(rule.name.clone(), total_iters);
        }
    }

    // Iteration-budget exhaustion is reported through the same flag as a
    // timeout: both mean "partial results".
    if !converged && !timeout_flag.load(std::sync::atomic::Ordering::Relaxed) {
        tracing::warn!(
            "fixpoint did not converge after {max_iterations} iterations; returning partial results",
        );
        timeout_flag.store(true, std::sync::atomic::Ordering::Relaxed);
    }

    let task_ctx = session_ctx.read().task_ctx();
    let mut all_output = Vec::new();

    // Post-loop: per-rule exact WMC (optional) then the post-fixpoint chain.
    for (rule_idx, state) in states.into_iter().enumerate() {
        let rule = &rules[rule_idx];
        let mut facts = state.into_facts();
        if facts.is_empty() {
            continue;
        }

        let shared_info = if let Some(ref tracker) = derivation_tracker {
            detect_shared_lineage(rule, &facts, tracker, &warnings_slot)
        } else {
            None
        };

        if exact_probability
            && let Some(ref info) = shared_info
            && let Some(ref tracker) = derivation_tracker
        {
            facts = apply_exact_wmc(
                facts,
                rule,
                info,
                tracker,
                max_bdd_variables,
                &warnings_slot,
                &approximate_slot,
            )?;
        }

        let processed = apply_post_fixpoint_chain(
            facts,
            rule,
            &task_ctx,
            strict_probability_domain,
            probability_epsilon,
        )
        .await?;
        all_output.extend(processed);
    }

    if all_output.is_empty() {
        all_output.push(RecordBatch::new_empty(output_schema));
    }

    Ok(all_output)
}
1306
/// Records a provenance annotation for every row of the rule's latest delta.
///
/// Each row is identified by a hash of its full scalar key; the annotation
/// captures the producing clause, supporting facts from referenced rules,
/// ALONG values, the iteration number, and (when `top_k_proofs > 0`) a proof
/// probability — in which case only the top-k proofs per row are kept.
fn record_provenance(
    tracker: &Arc<ProvenanceStore>,
    rule: &FixpointRulePlan,
    state: &FixpointState,
    clause_candidates: &[Vec<RecordBatch>],
    iteration: usize,
    registry: &Arc<DerivedScanRegistry>,
    top_k_proofs: usize,
) {
    let all_indices: Vec<usize> = (0..rule.yield_schema.fields().len()).collect();

    // Base-fact probabilities are only needed for proof scoring.
    let base_probs = if top_k_proofs > 0 {
        tracker.base_fact_probs()
    } else {
        HashMap::new()
    };

    for delta_batch in state.all_delta() {
        for row_idx in 0..delta_batch.num_rows() {
            // Row identity: debug-format of the full scalar key, as bytes.
            let row_hash = format!(
                "{:?}",
                extract_scalar_key(delta_batch, &all_indices, row_idx)
            )
            .into_bytes();
            let fact_row = batch_row_to_value_map(delta_batch, row_idx);
            let clause_index =
                find_clause_for_row(delta_batch, row_idx, &all_indices, clause_candidates);

            let support = collect_is_ref_inputs(rule, clause_index, delta_batch, row_idx, registry);

            let proof_probability = if top_k_proofs > 0 {
                compute_proof_probability(&support, &base_probs)
            } else {
                None
            };

            let entry = ProvenanceAnnotation {
                rule_name: rule.name.clone(),
                clause_index,
                support,
                along_values: {
                    let along_names: Vec<String> = rule
                        .clauses
                        .get(clause_index)
                        .map(|c| c.along_bindings.clone())
                        .unwrap_or_default();
                    along_names
                        .iter()
                        .filter_map(|name| fact_row.get(name).map(|v| (name.clone(), v.clone())))
                        .collect()
                },
                iteration,
                fact_row,
                proof_probability,
            };
            if top_k_proofs > 0 {
                tracker.record_top_k(row_hash, entry, top_k_proofs);
            } else {
                tracker.record(row_hash, entry);
            }
        }
    }
}
1379
/// Collects the supporting facts (proof terms) behind one delta row: for each
/// positive reference of the producing clause, finds the rows of the
/// referenced derived relation whose provenance-join columns match the body
/// row's values.
///
/// Negated references and references without provenance-join columns are
/// skipped, as are bindings whose body columns cannot all be resolved.
/// NOTE(review): matching is a full linear scan of the referenced relation
/// per delta row — may be hot for large derived relations.
fn collect_is_ref_inputs(
    rule: &FixpointRulePlan,
    clause_index: usize,
    delta_batch: &RecordBatch,
    row_idx: usize,
    registry: &Arc<DerivedScanRegistry>,
) -> Vec<ProofTerm> {
    let clause = match rule.clauses.get(clause_index) {
        Some(c) => c,
        None => return vec![],
    };

    let mut inputs = Vec::new();
    let delta_schema = delta_batch.schema();

    for binding in &clause.is_ref_bindings {
        if binding.negated {
            continue;
        }
        if binding.provenance_join_cols.is_empty() {
            continue;
        }

        // Values of this row's body-side join columns, resolved by name.
        let body_values: Vec<(String, ScalarKey)> = binding
            .provenance_join_cols
            .iter()
            .filter_map(|(body_col, _derived_col)| {
                let col_idx = delta_schema
                    .fields()
                    .iter()
                    .position(|f| f.name() == body_col)?;
                let key = extract_scalar_key(delta_batch, &[col_idx], row_idx);
                Some((body_col.clone(), key.into_iter().next()?))
            })
            .collect();

        // Require every join column to have resolved, or skip the binding.
        if body_values.len() != binding.provenance_join_cols.len() {
            continue;
        }

        let entry = match registry.get(binding.derived_scan_index) {
            Some(e) => e,
            None => continue,
        };
        let source_batches = entry.data.read();
        let source_schema = &entry.schema;

        for src_batch in source_batches.iter() {
            let all_src_indices: Vec<usize> = (0..src_batch.num_columns()).collect();
            for src_row in 0..src_batch.num_rows() {
                let matches = binding.provenance_join_cols.iter().enumerate().all(
                    |(i, (_body_col, derived_col))| {
                        let src_col_idx = source_schema
                            .fields()
                            .iter()
                            .position(|f| f.name() == derived_col);
                        match src_col_idx {
                            Some(idx) => {
                                let src_key = extract_scalar_key(src_batch, &[idx], src_row);
                                src_key.first() == Some(&body_values[i].1)
                            }
                            None => false,
                        }
                    },
                );
                if matches {
                    // Same full-row hash scheme as `record_provenance`.
                    let fact_hash = format!(
                        "{:?}",
                        extract_scalar_key(src_batch, &all_src_indices, src_row)
                    )
                    .into_bytes();
                    inputs.push(ProofTerm {
                        source_rule: binding.rule_name.clone(),
                        base_fact_id: fact_hash,
                    });
                }
            }
        }
    }

    inputs
}
1470
/// One pre-fold row of a shared-lineage group: the row's fact hash together
/// with the base facts its proof ultimately depends on.
#[expect(
    dead_code,
    reason = "Fields accessed via SharedLineageInfo in detect_shared_lineage"
)]
pub(crate) struct SharedGroupRow {
    /// Debug-format hash identifying this row in the provenance store.
    pub fact_hash: Vec<u8>,
    /// Base-fact hashes reachable from this row's proof tree.
    pub lineage: HashSet<Vec<u8>>,
}
1497
/// Result of shared-lineage detection: the KEY groups whose rows share base
/// facts, with per-row lineage sets (consumed by `apply_exact_wmc`).
pub(crate) struct SharedLineageInfo {
    /// KEY-group scalar key → rows in that group, each with its base lineage.
    pub shared_groups: HashMap<Vec<ScalarKey>, Vec<SharedGroupRow>>,
}
1503
1504fn fact_hash_key(batch: &RecordBatch, all_indices: &[usize], row_idx: usize) -> Vec<u8> {
1506 format!("{:?}", extract_scalar_key(batch, all_indices, row_idx)).into_bytes()
1507}
1508
/// Detects KEY groups whose rows' proofs share base facts — a condition that
/// breaks the independence assumption behind MNOR/MPROD aggregation.
///
/// Returns `None` when the rule has no probabilistic fold or when no sharing
/// is found; otherwise returns the affected groups with per-row lineage so
/// exact weighted model counting can correct them. As a side effect, pushes
/// deduplicated runtime warnings for cross-group correlations and shared
/// probabilistic dependencies into `warnings_slot`.
fn detect_shared_lineage(
    rule: &FixpointRulePlan,
    pre_fold_facts: &[RecordBatch],
    tracker: &Arc<ProvenanceStore>,
    warnings_slot: &Arc<StdRwLock<Vec<RuntimeWarning>>>,
) -> Option<SharedLineageInfo> {
    use crate::query::df_graph::locy_fold::FoldAggKind;
    use uni_locy::{RuntimeWarning, RuntimeWarningCode};

    // Only MNOR/MPROD folds assume independent inputs; anything else is safe.
    let has_prob_fold = rule
        .fold_bindings
        .iter()
        .any(|fb| matches!(fb.kind, FoldAggKind::Nor | FoldAggKind::Prod));
    if !has_prob_fold {
        return None;
    }

    let key_indices = &rule.key_column_indices;
    let all_indices: Vec<usize> = (0..rule.yield_schema.fields().len()).collect();

    // Group the pre-fold rows (as fact hashes) by their KEY column values.
    let mut groups: HashMap<Vec<ScalarKey>, Vec<Vec<u8>>> = HashMap::new();
    for batch in pre_fold_facts {
        for row_idx in 0..batch.num_rows() {
            let key = extract_scalar_key(batch, key_indices, row_idx);
            let fact_hash = fact_hash_key(batch, &all_indices, row_idx);
            groups.entry(key).or_default().push(fact_hash);
        }
    }

    let mut shared_groups: HashMap<Vec<ScalarKey>, Vec<SharedGroupRow>> = HashMap::new();
    let mut any_shared = false;

    for (key, fact_hashes) in &groups {
        // A single-row group cannot share lineage with itself.
        if fact_hashes.len() < 2 {
            continue;
        }

        // Resolve each row's base-fact lineage; note whether any row has
        // recorded support at all.
        let mut has_inputs = false;
        let mut per_row_bases: Vec<HashSet<Vec<u8>>> = Vec::new();
        for fh in fact_hashes {
            let bases = compute_lineage(fh, tracker, &mut HashSet::new());
            if let Some(entry) = tracker.lookup(fh)
                && !entry.support.is_empty()
            {
                has_inputs = true;
            }
            per_row_bases.push(bases);
        }

        let shared_found = if has_inputs {
            // Any pairwise overlap between base-fact sets means sharing.
            let mut found = false;
            'outer: for i in 0..per_row_bases.len() {
                for j in (i + 1)..per_row_bases.len() {
                    if !per_row_bases[i].is_disjoint(&per_row_bases[j]) {
                        found = true;
                        break 'outer;
                    }
                }
            }
            found
        } else {
            // No recorded support anywhere: conservatively flag the group if
            // any row's producing clause has a non-negated IS-ref binding
            // (lineage may exist but was not tracked).
            fact_hashes.iter().any(|fh| {
                tracker.lookup(fh).is_some_and(|entry| {
                    rule.clauses
                        .get(entry.clause_index)
                        .is_some_and(|clause| clause.is_ref_bindings.iter().any(|b| !b.negated))
                })
            })
        };

        if shared_found {
            any_shared = true;
            let rows: Vec<SharedGroupRow> = fact_hashes
                .iter()
                .zip(per_row_bases.into_iter())
                .map(|(fh, bases)| SharedGroupRow {
                    fact_hash: fh.clone(),
                    lineage: bases,
                })
                .collect();
            shared_groups.insert(key.clone(), rows);
        }
    }

    {
        // Separate check: warn when a single supporting input backs rows in
        // more than one KEY group — per-group BDD correction cannot account
        // for that cross-group correlation.
        let mut input_to_groups: HashMap<Vec<u8>, HashSet<Vec<ScalarKey>>> = HashMap::new();
        for (key, fact_hashes) in &groups {
            for fh in fact_hashes {
                if let Some(entry) = tracker.lookup(fh) {
                    for input in &entry.support {
                        input_to_groups
                            .entry(input.base_fact_id.clone())
                            .or_default()
                            .insert(key.clone());
                    }
                }
            }
        }
        let has_cross_group = input_to_groups.values().any(|g| g.len() > 1);
        if has_cross_group && let Ok(mut warnings) = warnings_slot.write() {
            // Deduplicate: one warning per (code, rule) pair.
            let already_warned = warnings.iter().any(|w| {
                w.code == RuntimeWarningCode::CrossGroupCorrelationNotExact
                    && w.rule_name == rule.name
            });
            if !already_warned {
                warnings.push(RuntimeWarning {
                    code: RuntimeWarningCode::CrossGroupCorrelationNotExact,
                    message: format!(
                        "Rule '{}': IS-ref base facts are shared across different KEY \
                         groups. BDD corrects per-group probabilities but cannot account \
                         for cross-group correlations.",
                        rule.name
                    ),
                    rule_name: rule.name.clone(),
                    variable_count: None,
                    key_group: None,
                });
            }
        }
    }

    if any_shared {
        if let Ok(mut warnings) = warnings_slot.write() {
            // Deduplicate: one warning per (code, rule) pair.
            let already_warned = warnings.iter().any(|w| {
                w.code == RuntimeWarningCode::SharedProbabilisticDependency
                    && w.rule_name == rule.name
            });
            if !already_warned {
                warnings.push(RuntimeWarning {
                    code: RuntimeWarningCode::SharedProbabilisticDependency,
                    message: format!(
                        "Rule '{}' aggregates with MNOR/MPROD but some proof paths \
                         share intermediate facts, violating the independence assumption. \
                         Results may overestimate probability.",
                        rule.name
                    ),
                    rule_name: rule.name.clone(),
                    variable_count: None,
                    key_group: None,
                });
            }
        }
        Some(SharedLineageInfo { shared_groups })
    } else {
        None
    }
}
1671
/// Non-recursive counterpart of `record_provenance` + `detect_shared_lineage`.
///
/// Facts arrive already tagged with the clause index that produced them, so
/// no clause search is needed and the recorded iteration is fixed at 0.
/// After recording every annotation, runs shared-lineage detection over all
/// facts combined and returns the result.
pub(crate) fn record_and_detect_lineage_nonrecursive(
    rule: &FixpointRulePlan,
    tagged_clause_facts: &[(usize, Vec<RecordBatch>)],
    tracker: &Arc<ProvenanceStore>,
    warnings_slot: &Arc<StdRwLock<Vec<RuntimeWarning>>>,
    registry: &Arc<DerivedScanRegistry>,
    top_k_proofs: usize,
) -> Option<SharedLineageInfo> {
    let all_indices: Vec<usize> = (0..rule.yield_schema.fields().len()).collect();

    // Base-fact probabilities are only needed when scoring top-k proofs.
    let base_probs = if top_k_proofs > 0 {
        tracker.base_fact_probs()
    } else {
        HashMap::new()
    };

    for (clause_index, batches) in tagged_clause_facts {
        for batch in batches {
            for row_idx in 0..batch.num_rows() {
                let row_hash = fact_hash_key(batch, &all_indices, row_idx);
                let fact_row = batch_row_to_value_map(batch, row_idx);

                let support = collect_is_ref_inputs(rule, *clause_index, batch, row_idx, registry);

                let proof_probability = if top_k_proofs > 0 {
                    compute_proof_probability(&support, &base_probs)
                } else {
                    None
                };

                let entry = ProvenanceAnnotation {
                    rule_name: rule.name.clone(),
                    clause_index: *clause_index,
                    support,
                    along_values: {
                        // Keep only the ALONG bindings declared by the clause
                        // that actually appear in this row.
                        let along_names: Vec<String> = rule
                            .clauses
                            .get(*clause_index)
                            .map(|c| c.along_bindings.clone())
                            .unwrap_or_default();
                        along_names
                            .iter()
                            .filter_map(|name| {
                                fact_row.get(name).map(|v| (name.clone(), v.clone()))
                            })
                            .collect()
                    },
                    iteration: 0,
                    fact_row,
                    proof_probability,
                };
                if top_k_proofs > 0 {
                    tracker.record_top_k(row_hash, entry, top_k_proofs);
                } else {
                    tracker.record(row_hash, entry);
                }
            }
        }
    }

    // Sharing is detected over the union of all clauses' facts.
    let all_facts: Vec<RecordBatch> = tagged_clause_facts
        .iter()
        .flat_map(|(_, batches)| batches.iter().cloned())
        .collect();
    detect_shared_lineage(rule, &all_facts, tracker, warnings_slot)
}
1748
/// Replaces per-group MNOR/MPROD inputs with an exact weighted model count
/// (BDD-based) for KEY groups whose rows share base-fact lineage.
///
/// Rows in non-shared groups pass through untouched. For each shared group a
/// BDD is built over the group's base facts: on success the group collapses
/// to a single representative row whose probability column is overwritten
/// with the exact value; when the BDD would exceed `max_bdd_variables`, the
/// group keeps all its original rows (independence-mode fallback) and a
/// warning plus an approximate-group record are emitted.
///
/// # Errors
/// Propagates Arrow errors from the row-take and batch-rebuild steps.
pub(crate) fn apply_exact_wmc(
    pre_fold_facts: Vec<RecordBatch>,
    rule: &FixpointRulePlan,
    shared_info: &SharedLineageInfo,
    tracker: &Arc<ProvenanceStore>,
    max_bdd_variables: usize,
    warnings_slot: &Arc<StdRwLock<Vec<RuntimeWarning>>>,
    approximate_slot: &Arc<StdRwLock<HashMap<String, Vec<String>>>>,
) -> DFResult<Vec<RecordBatch>> {
    use crate::query::df_graph::locy_bdd::{SemiringOp, weighted_model_count};
    use crate::query::df_graph::locy_fold::FoldAggKind;
    use uni_locy::{RuntimeWarning, RuntimeWarningCode};

    // Locate the probabilistic fold; without one there is nothing to correct.
    let prob_fold = rule
        .fold_bindings
        .iter()
        .find(|fb| matches!(fb.kind, FoldAggKind::Nor | FoldAggKind::Prod));
    let prob_fold = match prob_fold {
        Some(f) => f,
        None => return Ok(pre_fold_facts),
    };
    // MNOR combines disjunctively, MPROD conjunctively.
    let semiring_op = if matches!(prob_fold.kind, FoldAggKind::Nor) {
        SemiringOp::Disjunction
    } else {
        SemiringOp::Conjunction
    };
    let prob_col_idx = prob_fold.input_col_index;
    let prob_col_name = rule.yield_schema.field(prob_col_idx).name().clone();

    let key_indices = &rule.key_column_indices;
    let all_indices: Vec<usize> = (0..rule.yield_schema.fields().len()).collect();

    let shared_keys: HashSet<Vec<ScalarKey>> = shared_info.shared_groups.keys().cloned().collect();

    // Per-group accumulator: one lineage set per row, base-fact probabilities,
    // the first row seen (kept as the group's representative), and every
    // (batch, row) location belonging to the group.
    struct GroupAccum {
        base_facts: Vec<HashSet<Vec<u8>>>,
        base_probs: HashMap<Vec<u8>, f64>,
        representative: (usize, usize),
        row_locations: Vec<(usize, usize)>,
    }

    let mut group_accums: HashMap<Vec<ScalarKey>, GroupAccum> = HashMap::new();
    let mut non_shared_rows: Vec<(usize, usize)> = Vec::new(); for (batch_idx, batch) in pre_fold_facts.iter().enumerate() {
        for row_idx in 0..batch.num_rows() {
            let key = extract_scalar_key(batch, key_indices, row_idx);
            if shared_keys.contains(&key) {
                let fact_hash = fact_hash_key(batch, &all_indices, row_idx);
                let bases = compute_lineage(&fact_hash, tracker, &mut HashSet::new());

                let accum = group_accums.entry(key).or_insert_with(|| GroupAccum {
                    base_facts: Vec::new(),
                    base_probs: HashMap::new(),
                    representative: (batch_idx, row_idx),
                    row_locations: Vec::new(),
                });

                // Pull each base fact's probability (from its recorded row)
                // the first time the base fact is encountered in this group.
                for bf in &bases {
                    if !accum.base_probs.contains_key(bf)
                        && let Some(entry) = tracker.lookup(bf)
                        && let Some(val) = entry.fact_row.get(&prob_col_name)
                        && let Some(p) = value_to_f64(val)
                    {
                        accum.base_probs.insert(bf.clone(), p);
                    }
                }

                accum.base_facts.push(bases);
                accum.row_locations.push((batch_idx, row_idx));
            } else {
                non_shared_rows.push((batch_idx, row_idx));
            }
        }
    }

    // Rows to keep in the output, and per-location probability overrides.
    let mut keep_rows: HashSet<(usize, usize)> = HashSet::new();
    let mut overrides: HashMap<(usize, usize), f64> = HashMap::new();

    // Non-shared rows always survive unchanged.
    for &loc in &non_shared_rows {
        keep_rows.insert(loc);
    }

    for (key, accum) in &group_accums {
        let bdd_result = weighted_model_count(
            &accum.base_facts,
            &accum.base_probs,
            semiring_op,
            max_bdd_variables,
        );

        if bdd_result.approximated {
            // BDD blew past the variable limit: warn (once per rule+group),
            // record the group as approximate, and keep all original rows so
            // the independence-mode fold result stands.
            if let Ok(mut warnings) = warnings_slot.write() {
                let key_desc = format!("{key:?}");
                let already_warned = warnings.iter().any(|w| {
                    w.code == RuntimeWarningCode::BddLimitExceeded
                        && w.rule_name == rule.name
                        && w.key_group.as_deref() == Some(&key_desc)
                });
                if !already_warned {
                    warnings.push(RuntimeWarning {
                        code: RuntimeWarningCode::BddLimitExceeded,
                        message: format!(
                            "Rule '{}': BDD variable limit exceeded ({} > {}). \
                             Falling back to independence-mode result.",
                            rule.name, bdd_result.variable_count, max_bdd_variables
                        ),
                        rule_name: rule.name.clone(),
                        variable_count: Some(bdd_result.variable_count),
                        key_group: Some(key_desc),
                    });
                }
            }
            if let Ok(mut approx) = approximate_slot.write() {
                let key_desc = format!("{key:?}");
                approx.entry(rule.name.clone()).or_default().push(key_desc);
            }
            for &loc in &accum.row_locations {
                keep_rows.insert(loc);
            }
        } else {
            // Exact count succeeded: collapse the group to its representative
            // row and override that row's probability with the WMC value.
            keep_rows.insert(accum.representative);
            overrides.insert(accum.representative, bdd_result.probability);
        }
    }

    // Rebuild each batch: take only kept rows, then splice in any overrides.
    let mut result_batches = Vec::new();
    for (batch_idx, batch) in pre_fold_facts.iter().enumerate() {
        let kept_indices: Vec<usize> = (0..batch.num_rows())
            .filter(|&row_idx| keep_rows.contains(&(batch_idx, row_idx)))
            .collect();

        if kept_indices.is_empty() {
            continue;
        }

        let indices = arrow::array::UInt32Array::from(
            kept_indices.iter().map(|&i| i as u32).collect::<Vec<_>>(),
        );
        let mut columns: Vec<arrow::array::ArrayRef> = batch
            .columns()
            .iter()
            .map(|col| arrow::compute::take(col, &indices, None))
            .collect::<Result<Vec<_>, _>>()
            .map_err(arrow_err)?;

        // Per kept row: the override probability, if any.
        let override_map: Vec<Option<f64>> = kept_indices
            .iter()
            .map(|&row_idx| overrides.get(&(batch_idx, row_idx)).copied())
            .collect();

        if override_map.iter().any(|o| o.is_some()) && prob_col_idx < columns.len() {
            // Rebuild the probability column, keeping existing values where
            // no override applies (0.0 when the column is not Float64).
            let existing_prob = columns[prob_col_idx]
                .as_any()
                .downcast_ref::<arrow::array::Float64Array>();
            let new_values: Vec<f64> = override_map
                .iter()
                .enumerate()
                .map(|(i, ov)| match ov {
                    Some(p) => *p,
                    None => existing_prob.map(|arr| arr.value(i)).unwrap_or(0.0),
                })
                .collect();
            columns[prob_col_idx] = Arc::new(arrow::array::Float64Array::from(new_values));
        }

        let result_batch = RecordBatch::try_new(batch.schema(), columns).map_err(arrow_err)?;
        result_batches.push(result_batch);
    }

    Ok(result_batches)
}
1944
1945fn value_to_f64(val: &uni_common::Value) -> Option<f64> {
1947 match val {
1948 uni_common::Value::Float(f) => Some(*f),
1949 uni_common::Value::Int(i) => Some(*i as f64),
1950 _ => None,
1951 }
1952}
1953
/// Recursively resolves a fact hash to the set of base-fact hashes that its
/// proof ultimately rests on. A fact with no recorded support (or absent
/// from the store) is its own base.
///
/// `visited` guards against cycles in the provenance graph: a fact already
/// seen on the current traversal contributes an empty set (the cycle adds
/// no new base facts).
fn compute_lineage(
    fact_hash: &[u8],
    tracker: &Arc<ProvenanceStore>,
    visited: &mut HashSet<Vec<u8>>,
) -> HashSet<Vec<u8>> {
    // insert() returning false means we have already visited this fact.
    if !visited.insert(fact_hash.to_vec()) {
        return HashSet::new(); }

    match tracker.lookup(fact_hash) {
        Some(entry) if !entry.support.is_empty() => {
            // Derived fact: union the base sets of every supporting input.
            let mut bases = HashSet::new();
            for input in &entry.support {
                let child_bases = compute_lineage(&input.base_fact_id, tracker, visited);
                bases.extend(child_bases);
            }
            bases
        }
        _ => {
            // No support recorded: treat the fact itself as a base fact.
            let mut set = HashSet::new();
            set.insert(fact_hash.to_vec());
            set
        }
    }
}
1986
1987fn find_clause_for_row(
1992 delta_batch: &RecordBatch,
1993 row_idx: usize,
1994 all_indices: &[usize],
1995 clause_candidates: &[Vec<RecordBatch>],
1996) -> usize {
1997 let target_key = extract_scalar_key(delta_batch, all_indices, row_idx);
1998 for (clause_idx, batches) in clause_candidates.iter().enumerate() {
1999 for batch in batches {
2000 if batch.num_columns() != all_indices.len() {
2001 continue;
2002 }
2003 for r in 0..batch.num_rows() {
2004 if extract_scalar_key(batch, all_indices, r) == target_key {
2005 return clause_idx;
2006 }
2007 }
2008 }
2009 }
2010 0
2011}
2012
2013fn batch_row_to_value_map(
2015 batch: &RecordBatch,
2016 row_idx: usize,
2017) -> std::collections::HashMap<String, Value> {
2018 use uni_store::storage::arrow_convert::arrow_to_value;
2019
2020 let schema = batch.schema();
2021 schema
2022 .fields()
2023 .iter()
2024 .enumerate()
2025 .map(|(col_idx, field)| {
2026 let col = batch.column(col_idx);
2027 let val = arrow_to_value(col.as_ref(), row_idx, None);
2028 (field.name().clone(), val)
2029 })
2030 .collect()
2031}
2032
2033pub fn apply_anti_join(
2038 batches: Vec<RecordBatch>,
2039 neg_facts: &[RecordBatch],
2040 left_col: &str,
2041 right_col: &str,
2042) -> datafusion::error::Result<Vec<RecordBatch>> {
2043 use arrow::compute::filter_record_batch;
2044 use arrow_array::{Array as _, BooleanArray, UInt64Array};
2045
2046 let mut banned: std::collections::HashSet<u64> = std::collections::HashSet::new();
2048 for batch in neg_facts {
2049 let Ok(idx) = batch.schema().index_of(right_col) else {
2050 continue;
2051 };
2052 let arr = batch.column(idx);
2053 let Some(vids) = arr.as_any().downcast_ref::<UInt64Array>() else {
2054 continue;
2055 };
2056 for i in 0..vids.len() {
2057 if !vids.is_null(i) {
2058 banned.insert(vids.value(i));
2059 }
2060 }
2061 }
2062
2063 if banned.is_empty() {
2064 return Ok(batches);
2065 }
2066
2067 let mut result = Vec::new();
2069 for batch in batches {
2070 let Ok(idx) = batch.schema().index_of(left_col) else {
2071 result.push(batch);
2072 continue;
2073 };
2074 let arr = batch.column(idx);
2075 let Some(vids) = arr.as_any().downcast_ref::<UInt64Array>() else {
2076 result.push(batch);
2077 continue;
2078 };
2079 let keep: Vec<bool> = (0..vids.len())
2080 .map(|i| vids.is_null(i) || !banned.contains(&vids.value(i)))
2081 .collect();
2082 let keep_arr = BooleanArray::from(keep);
2083 let filtered = filter_record_batch(&batch, &keep_arr).map_err(arrow_err)?;
2084 if filtered.num_rows() > 0 {
2085 result.push(filtered);
2086 }
2087 }
2088 Ok(result)
2089}
2090
/// Appends a probabilistic-complement column for a negated IS-ref clause
/// joined on a single UInt64 column.
///
/// Probabilities of repeated ids in `neg_facts` are combined via noisy-OR
/// (`1 - Π(1 - pᵢ)`); each input batch then gains a nullable Float64 column
/// `complement_col_name` holding `1 - p(id)` per row (1.0 for null ids or
/// ids absent from the negative facts). Batches lacking the join column, or
/// whose join column is not UInt64, pass through unchanged.
pub fn apply_prob_complement(
    batches: Vec<RecordBatch>,
    neg_facts: &[RecordBatch],
    left_col: &str,
    right_col: &str,
    prob_col: &str,
    complement_col_name: &str,
) -> datafusion::error::Result<Vec<RecordBatch>> {
    use arrow_array::{Array as _, Float64Array, UInt64Array};

    // id → combined negation probability (noisy-OR over all occurrences).
    // Negative-fact batches missing either column are skipped entirely.
    let mut prob_map: std::collections::HashMap<u64, f64> = std::collections::HashMap::new();
    for batch in neg_facts {
        let Ok(vid_idx) = batch.schema().index_of(right_col) else {
            continue;
        };
        let Ok(prob_idx) = batch.schema().index_of(prob_col) else {
            continue;
        };
        let Some(vids) = batch.column(vid_idx).as_any().downcast_ref::<UInt64Array>() else {
            continue;
        };
        let prob_arr = batch.column(prob_idx);
        let probs = prob_arr.as_any().downcast_ref::<Float64Array>();
        for i in 0..vids.len() {
            if !vids.is_null(i) {
                // Null / non-Float64 probabilities count as 0.
                let p = probs
                    .and_then(|arr| {
                        if arr.is_null(i) {
                            None
                        } else {
                            Some(arr.value(i))
                        }
                    })
                    .unwrap_or(0.0);
                prob_map
                    .entry(vids.value(i))
                    .and_modify(|existing| {
                        // Noisy-OR accumulation across duplicate ids.
                        *existing = 1.0 - (1.0 - *existing) * (1.0 - p);
                    })
                    .or_insert(p);
            }
        }
    }

    let mut result = Vec::new();
    for batch in batches {
        let Ok(idx) = batch.schema().index_of(left_col) else {
            result.push(batch);
            continue;
        };
        let Some(vids) = batch.column(idx).as_any().downcast_ref::<UInt64Array>() else {
            result.push(batch);
            continue;
        };

        // Per-row complement: 1 - p(id); null or unknown ids get 1.0.
        let complements: Vec<f64> = (0..vids.len())
            .map(|i| {
                if vids.is_null(i) {
                    1.0
                } else {
                    let p = prob_map.get(&vids.value(i)).copied().unwrap_or(0.0);
                    1.0 - p
                }
            })
            .collect();

        let complement_arr = Float64Array::from(complements);

        // Extend the batch with the new column and matching schema field.
        let mut columns: Vec<arrow_array::ArrayRef> = batch.columns().to_vec();
        columns.push(std::sync::Arc::new(complement_arr));

        let mut fields: Vec<std::sync::Arc<arrow_schema::Field>> =
            batch.schema().fields().iter().cloned().collect();
        fields.push(std::sync::Arc::new(arrow_schema::Field::new(
            complement_col_name,
            arrow_schema::DataType::Float64,
            true,
        )));

        let new_schema = std::sync::Arc::new(arrow_schema::Schema::new(fields));
        let new_batch = RecordBatch::try_new(new_schema, columns).map_err(arrow_err)?;
        result.push(new_batch);
    }
    Ok(result)
}
2190
/// Multi-column variant of [`apply_prob_complement`]: the negation join key
/// is the tuple of UInt64 values named by `join_cols` (left, right pairs).
///
/// Rows with any null join value receive a complement of 1.0; batches whose
/// join columns are missing or not all UInt64 pass through unchanged.
/// Duplicate key tuples in `neg_facts` combine via noisy-OR.
pub fn apply_prob_complement_composite(
    batches: Vec<RecordBatch>,
    neg_facts: &[RecordBatch],
    join_cols: &[(String, String)],
    prob_col: &str,
    complement_col_name: &str,
) -> datafusion::error::Result<Vec<RecordBatch>> {
    use arrow_array::{Array as _, Float64Array, UInt64Array};

    // key tuple → combined negation probability (noisy-OR).
    let mut prob_map: HashMap<Vec<u64>, f64> = HashMap::new();
    for batch in neg_facts {
        let right_indices: Vec<usize> = join_cols
            .iter()
            .filter_map(|(_, rc)| batch.schema().index_of(rc).ok())
            .collect();
        // All right-side join columns must be present.
        if right_indices.len() != join_cols.len() {
            continue;
        }
        let Ok(prob_idx) = batch.schema().index_of(prob_col) else {
            continue;
        };
        let prob_arr = batch.column(prob_idx);
        let probs = prob_arr.as_any().downcast_ref::<Float64Array>();
        for row in 0..batch.num_rows() {
            // Build the key tuple; any null or non-UInt64 value invalidates the row.
            let mut key = Vec::with_capacity(right_indices.len());
            let mut valid = true;
            for &ci in &right_indices {
                let col = batch.column(ci);
                if let Some(vids) = col.as_any().downcast_ref::<UInt64Array>() {
                    if vids.is_null(row) {
                        valid = false;
                        break;
                    }
                    key.push(vids.value(row));
                } else {
                    valid = false;
                    break;
                }
            }
            if !valid {
                continue;
            }
            // Null / non-Float64 probabilities count as 0.
            let p = probs
                .and_then(|arr| {
                    if arr.is_null(row) {
                        None
                    } else {
                        Some(arr.value(row))
                    }
                })
                .unwrap_or(0.0);
            prob_map
                .entry(key)
                .and_modify(|existing| {
                    // Noisy-OR accumulation across duplicate key tuples.
                    *existing = 1.0 - (1.0 - *existing) * (1.0 - p);
                })
                .or_insert(p);
        }
    }

    let mut result = Vec::new();
    for batch in batches {
        let left_indices: Vec<usize> = join_cols
            .iter()
            .filter_map(|(lc, _)| batch.schema().index_of(lc).ok())
            .collect();
        // Missing left-side columns: pass the batch through unchanged.
        if left_indices.len() != join_cols.len() {
            result.push(batch);
            continue;
        }
        let all_u64 = left_indices.iter().all(|&ci| {
            batch
                .column(ci)
                .as_any()
                .downcast_ref::<UInt64Array>()
                .is_some()
        });
        if !all_u64 {
            result.push(batch);
            continue;
        }

        // Per-row complement: 1 - p(key); rows with null keys get 1.0.
        let complements: Vec<f64> = (0..batch.num_rows())
            .map(|row| {
                let mut key = Vec::with_capacity(left_indices.len());
                for &ci in &left_indices {
                    // unwrap is safe: all_u64 verified every column above.
                    let vids = batch
                        .column(ci)
                        .as_any()
                        .downcast_ref::<UInt64Array>()
                        .unwrap();
                    if vids.is_null(row) {
                        return 1.0;
                    }
                    key.push(vids.value(row));
                }
                let p = prob_map.get(&key).copied().unwrap_or(0.0);
                1.0 - p
            })
            .collect();

        // Extend the batch with the new column and matching schema field.
        let complement_arr = Float64Array::from(complements);
        let mut columns: Vec<arrow_array::ArrayRef> = batch.columns().to_vec();
        columns.push(Arc::new(complement_arr));

        let mut fields: Vec<Arc<arrow_schema::Field>> =
            batch.schema().fields().iter().cloned().collect();
        fields.push(Arc::new(arrow_schema::Field::new(
            complement_col_name,
            arrow_schema::DataType::Float64,
            true,
        )));

        let new_schema = Arc::new(arrow_schema::Schema::new(fields));
        let new_batch = RecordBatch::try_new(new_schema, columns).map_err(arrow_err)?;
        result.push(new_batch);
    }
    Ok(result)
}
2319
/// Multi-column variant of [`apply_anti_join`]: removes rows whose tuple of
/// `join_cols` UInt64 values matches any complete, non-null tuple from
/// `neg_facts`.
///
/// Rows with any null join value are always kept; batches whose join columns
/// are missing or not all UInt64 pass through unchanged; fully-filtered
/// batches are dropped from the output.
pub fn apply_anti_join_composite(
    batches: Vec<RecordBatch>,
    neg_facts: &[RecordBatch],
    join_cols: &[(String, String)],
) -> datafusion::error::Result<Vec<RecordBatch>> {
    use arrow::compute::filter_record_batch;
    use arrow_array::{Array as _, BooleanArray, UInt64Array};

    // Collect banned key tuples from the negative facts.
    let mut banned: HashSet<Vec<u64>> = HashSet::new();
    for batch in neg_facts {
        let right_indices: Vec<usize> = join_cols
            .iter()
            .filter_map(|(_, rc)| batch.schema().index_of(rc).ok())
            .collect();
        // All right-side join columns must be present.
        if right_indices.len() != join_cols.len() {
            continue;
        }
        for row in 0..batch.num_rows() {
            // Build the key tuple; null or non-UInt64 values invalidate the row.
            let mut key = Vec::with_capacity(right_indices.len());
            let mut valid = true;
            for &ci in &right_indices {
                let col = batch.column(ci);
                if let Some(vids) = col.as_any().downcast_ref::<UInt64Array>() {
                    if vids.is_null(row) {
                        valid = false;
                        break;
                    }
                    key.push(vids.value(row));
                } else {
                    valid = false;
                    break;
                }
            }
            if valid {
                banned.insert(key);
            }
        }
    }

    // Nothing is banned: the input passes through unchanged.
    if banned.is_empty() {
        return Ok(batches);
    }

    let mut result = Vec::new();
    for batch in batches {
        let left_indices: Vec<usize> = join_cols
            .iter()
            .filter_map(|(lc, _)| batch.schema().index_of(lc).ok())
            .collect();
        // Missing left-side columns: pass the batch through unchanged.
        if left_indices.len() != join_cols.len() {
            result.push(batch);
            continue;
        }
        let all_u64 = left_indices.iter().all(|&ci| {
            batch
                .column(ci)
                .as_any()
                .downcast_ref::<UInt64Array>()
                .is_some()
        });
        if !all_u64 {
            result.push(batch);
            continue;
        }

        let keep: Vec<bool> = (0..batch.num_rows())
            .map(|row| {
                let mut key = Vec::with_capacity(left_indices.len());
                for &ci in &left_indices {
                    // unwrap is safe: all_u64 verified every column above.
                    let vids = batch
                        .column(ci)
                        .as_any()
                        .downcast_ref::<UInt64Array>()
                        .unwrap();
                    // Rows with a null join value are never banned.
                    if vids.is_null(row) {
                        return true; }
                    key.push(vids.value(row));
                }
                !banned.contains(&key)
            })
            .collect();
        let keep_arr = BooleanArray::from(keep);
        let filtered = filter_record_batch(&batch, &keep_arr).map_err(arrow_err)?;
        if filtered.num_rows() > 0 {
            result.push(filtered);
        }
    }
    Ok(result)
}
2417
2418pub fn multiply_prob_factors(
2429 batches: Vec<RecordBatch>,
2430 prob_col: Option<&str>,
2431 complement_cols: &[String],
2432) -> datafusion::error::Result<Vec<RecordBatch>> {
2433 use arrow_array::{Array as _, Float64Array};
2434
2435 let mut result = Vec::with_capacity(batches.len());
2436
2437 for batch in batches {
2438 if batch.num_rows() == 0 {
2439 let keep: Vec<usize> = batch
2441 .schema()
2442 .fields()
2443 .iter()
2444 .enumerate()
2445 .filter(|(_, f)| !complement_cols.contains(f.name()))
2446 .map(|(i, _)| i)
2447 .collect();
2448 let fields: Vec<_> = keep
2449 .iter()
2450 .map(|&i| batch.schema().field(i).clone())
2451 .collect();
2452 let cols: Vec<_> = keep.iter().map(|&i| batch.column(i).clone()).collect();
2453 let schema = std::sync::Arc::new(arrow_schema::Schema::new(fields));
2454 result.push(
2455 RecordBatch::try_new(schema, cols).map_err(|e| {
2456 datafusion::error::DataFusionError::ArrowError(Box::new(e), None)
2457 })?,
2458 );
2459 continue;
2460 }
2461
2462 let num_rows = batch.num_rows();
2463
2464 let mut combined = vec![1.0f64; num_rows];
2466 for col_name in complement_cols {
2467 if let Ok(idx) = batch.schema().index_of(col_name) {
2468 let arr = batch
2469 .column(idx)
2470 .as_any()
2471 .downcast_ref::<Float64Array>()
2472 .ok_or_else(|| {
2473 datafusion::error::DataFusionError::Internal(format!(
2474 "Expected Float64 for complement column {col_name}"
2475 ))
2476 })?;
2477 for (i, val) in combined.iter_mut().enumerate().take(num_rows) {
2478 if !arr.is_null(i) {
2479 *val *= arr.value(i);
2480 }
2481 }
2482 }
2483 }
2484
2485 let final_prob: Vec<f64> = if let Some(prob_name) = prob_col {
2487 if let Ok(idx) = batch.schema().index_of(prob_name) {
2488 let arr = batch
2489 .column(idx)
2490 .as_any()
2491 .downcast_ref::<Float64Array>()
2492 .ok_or_else(|| {
2493 datafusion::error::DataFusionError::Internal(format!(
2494 "Expected Float64 for PROB column {prob_name}"
2495 ))
2496 })?;
2497 (0..num_rows)
2498 .map(|i| {
2499 if arr.is_null(i) {
2500 combined[i]
2501 } else {
2502 arr.value(i) * combined[i]
2503 }
2504 })
2505 .collect()
2506 } else {
2507 combined
2508 }
2509 } else {
2510 combined
2511 };
2512
2513 let new_prob_array: arrow_array::ArrayRef =
2514 std::sync::Arc::new(Float64Array::from(final_prob));
2515
2516 let mut fields = Vec::new();
2518 let mut columns = Vec::new();
2519
2520 for (idx, field) in batch.schema().fields().iter().enumerate() {
2521 if complement_cols.contains(field.name()) {
2522 continue;
2523 }
2524 if prob_col.is_some_and(|p| field.name() == p) {
2525 fields.push(field.clone());
2526 columns.push(new_prob_array.clone());
2527 } else {
2528 fields.push(field.clone());
2529 columns.push(batch.column(idx).clone());
2530 }
2531 }
2532
2533 let schema = std::sync::Arc::new(arrow_schema::Schema::new(fields));
2534 result.push(RecordBatch::try_new(schema, columns).map_err(arrow_err)?);
2535 }
2536
2537 Ok(result)
2538}
2539
2540fn update_derived_scan_handles(
2545 registry: &DerivedScanRegistry,
2546 states: &[FixpointState],
2547 current_rule_idx: usize,
2548 rules: &[FixpointRulePlan],
2549) {
2550 let current_rule_name = &rules[current_rule_idx].name;
2551
2552 for entry in ®istry.entries {
2553 let source_state_idx = rules.iter().position(|r| r.name == entry.rule_name);
2555 let Some(source_idx) = source_state_idx else {
2556 continue;
2557 };
2558
2559 let is_self = entry.rule_name == *current_rule_name;
2560 let data = if is_self {
2561 states[source_idx].all_delta().to_vec()
2563 } else {
2564 states[source_idx].all_facts().to_vec()
2566 };
2567
2568 let data = if data.is_empty() || data.iter().all(|b| b.num_rows() == 0) {
2570 vec![RecordBatch::new_empty(Arc::clone(&entry.schema))]
2571 } else {
2572 data
2573 };
2574
2575 let mut guard = entry.data.write();
2576 *guard = data;
2577 }
2578}
2579
/// Leaf execution node that scans a shared, externally-mutable batch buffer.
/// The buffer is the same `Arc<RwLock<...>>` held by a registry entry, so its
/// contents can be swapped between executions (see `update_derived_scan_handles`).
pub struct DerivedScanExec {
    /// Shared batch buffer read at `execute` time.
    data: Arc<RwLock<Vec<RecordBatch>>>,
    /// Output schema; also used to synthesize an empty batch when the buffer is empty.
    schema: SchemaRef,
    /// Plan properties computed once from `schema`.
    properties: PlanProperties,
}
2594
2595impl DerivedScanExec {
2596 pub fn new(data: Arc<RwLock<Vec<RecordBatch>>>, schema: SchemaRef) -> Self {
2597 let properties = compute_plan_properties(Arc::clone(&schema));
2598 Self {
2599 data,
2600 schema,
2601 properties,
2602 }
2603 }
2604}
2605
// Manual Debug: the (potentially large) batch buffer is intentionally
// omitted; only the schema is shown.
impl fmt::Debug for DerivedScanExec {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("DerivedScanExec")
            .field("schema", &self.schema)
            .finish()
    }
}
2613
2614impl DisplayAs for DerivedScanExec {
2615 fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter<'_>) -> fmt::Result {
2616 write!(f, "DerivedScanExec")
2617 }
2618}
2619
2620impl ExecutionPlan for DerivedScanExec {
2621 fn name(&self) -> &str {
2622 "DerivedScanExec"
2623 }
2624 fn as_any(&self) -> &dyn Any {
2625 self
2626 }
2627 fn schema(&self) -> SchemaRef {
2628 Arc::clone(&self.schema)
2629 }
2630 fn properties(&self) -> &PlanProperties {
2631 &self.properties
2632 }
2633 fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
2634 vec![]
2635 }
2636 fn with_new_children(
2637 self: Arc<Self>,
2638 _children: Vec<Arc<dyn ExecutionPlan>>,
2639 ) -> DFResult<Arc<dyn ExecutionPlan>> {
2640 Ok(self)
2641 }
2642 fn execute(
2643 &self,
2644 _partition: usize,
2645 _context: Arc<TaskContext>,
2646 ) -> DFResult<SendableRecordBatchStream> {
2647 let batches = {
2648 let guard = self.data.read();
2649 if guard.is_empty() {
2650 vec![RecordBatch::new_empty(Arc::clone(&self.schema))]
2651 } else {
2652 guard.clone()
2653 }
2654 };
2655 Ok(Box::pin(MemoryStream::try_new(
2656 batches,
2657 Arc::clone(&self.schema),
2658 None,
2659 )?))
2660 }
2661}
2662
/// Leaf execution node over a fixed, owned set of batches; each `execute`
/// call replays a clone of them as a stream.
struct InMemoryExec {
    /// Batches replayed (cloned) on every execution.
    batches: Vec<RecordBatch>,
    /// Output schema of `batches`.
    schema: SchemaRef,
    /// Plan properties computed once from `schema`.
    properties: PlanProperties,
}
2676
2677impl InMemoryExec {
2678 fn new(batches: Vec<RecordBatch>, schema: SchemaRef) -> Self {
2679 let properties = compute_plan_properties(Arc::clone(&schema));
2680 Self {
2681 batches,
2682 schema,
2683 properties,
2684 }
2685 }
2686}
2687
2688impl fmt::Debug for InMemoryExec {
2689 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
2690 f.debug_struct("InMemoryExec")
2691 .field("num_batches", &self.batches.len())
2692 .field("schema", &self.schema)
2693 .finish()
2694 }
2695}
2696
2697impl DisplayAs for InMemoryExec {
2698 fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter<'_>) -> fmt::Result {
2699 write!(f, "InMemoryExec: batches={}", self.batches.len())
2700 }
2701}
2702
2703impl ExecutionPlan for InMemoryExec {
2704 fn name(&self) -> &str {
2705 "InMemoryExec"
2706 }
2707 fn as_any(&self) -> &dyn Any {
2708 self
2709 }
2710 fn schema(&self) -> SchemaRef {
2711 Arc::clone(&self.schema)
2712 }
2713 fn properties(&self) -> &PlanProperties {
2714 &self.properties
2715 }
2716 fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
2717 vec![]
2718 }
2719 fn with_new_children(
2720 self: Arc<Self>,
2721 _children: Vec<Arc<dyn ExecutionPlan>>,
2722 ) -> DFResult<Arc<dyn ExecutionPlan>> {
2723 Ok(self)
2724 }
2725 fn execute(
2726 &self,
2727 _partition: usize,
2728 _context: Arc<TaskContext>,
2729 ) -> DFResult<SendableRecordBatchStream> {
2730 Ok(Box::pin(MemoryStream::try_new(
2731 self.batches.clone(),
2732 Arc::clone(&self.schema),
2733 None,
2734 )?))
2735 }
2736}
2737
/// Applies HAVING predicates to already-aggregated batches.
///
/// Each expression in `having_exprs` is converted from the Cypher AST to a
/// DataFusion logical expression, type-coerced against `schema`, lowered to
/// a physical expression, and evaluated per batch. A row survives only if
/// ALL predicates are true (masks are AND-ed); batches left with zero rows
/// are dropped from the result.
///
/// # Errors
/// Fails if expression conversion fails or a predicate does not evaluate
/// to a boolean array.
fn apply_having_filter(
    batches: Vec<RecordBatch>,
    having_exprs: &[Expr],
    schema: &SchemaRef,
) -> DFResult<Vec<RecordBatch>> {
    use arrow::compute::{and, filter_record_batch};
    use arrow_array::BooleanArray;
    use datafusion::common::DFSchema;
    use datafusion::logical_expr::LogicalPlanBuilder;
    use datafusion::optimizer::AnalyzerRule;
    use datafusion::optimizer::analyzer::type_coercion::TypeCoercion;
    use datafusion::physical_expr::create_physical_expr;
    use datafusion::prelude::SessionContext;

    if batches.is_empty() {
        return Ok(batches);
    }

    let df_schema = DFSchema::try_from(schema.as_ref().clone()).map_err(|e| {
        datafusion::common::DataFusionError::Internal(format!("HAVING schema conversion: {e}"))
    })?;

    // A throwaway session supplies the config/execution props that the
    // analyzer and physical planner require.
    let ctx = SessionContext::new();
    let state = ctx.state();
    let config = state.config_options().clone();
    let props = state.execution_props();

    // Build one physical predicate per HAVING expression.
    let physical_exprs: Vec<Arc<dyn datafusion::physical_expr::PhysicalExpr>> = having_exprs
        .iter()
        .map(|expr| {
            let df_expr = crate::query::df_expr::cypher_expr_to_df(expr, None).map_err(|e| {
                datafusion::common::DataFusionError::Internal(format!(
                    "HAVING expression conversion: {e}"
                ))
            })?;

            // Run the TypeCoercion analyzer over a synthetic Filter plan so
            // the predicate gets the same casts a normal query would.
            let empty = datafusion::logical_expr::LogicalPlan::EmptyRelation(
                datafusion::logical_expr::EmptyRelation {
                    produce_one_row: false,
                    schema: Arc::new(df_schema.clone()),
                },
            );
            let filter_plan = LogicalPlanBuilder::from(empty)
                .filter(df_expr.clone())?
                .build()?;
            // If coercion does not hand back a Filter plan, fall back to the
            // uncoerced expression rather than failing.
            let coerced_expr = match TypeCoercion::new().analyze(filter_plan, &config) {
                Ok(datafusion::logical_expr::LogicalPlan::Filter(f)) => f.predicate,
                _ => df_expr,
            };

            create_physical_expr(&coerced_expr, &df_schema, props)
        })
        .collect::<DFResult<Vec<_>>>()?;

    let mut result = Vec::new();
    for batch in batches {
        // AND all predicate masks together; None means "no predicates".
        let mut mask: Option<BooleanArray> = None;
        for phys_expr in &physical_exprs {
            let value = phys_expr.evaluate(&batch)?;
            let arr = value.into_array(batch.num_rows())?;
            let bool_arr = arr.as_any().downcast_ref::<BooleanArray>().ok_or_else(|| {
                datafusion::common::DataFusionError::Internal(
                    "HAVING condition must evaluate to boolean".into(),
                )
            })?;
            mask = Some(match mask {
                None => bool_arr.clone(),
                Some(prev) => and(&prev, bool_arr).map_err(arrow_err)?,
            });
        }
        if let Some(ref m) = mask {
            let filtered = filter_record_batch(&batch, m).map_err(arrow_err)?;
            // Fully-filtered batches are dropped rather than kept empty.
            if filtered.num_rows() > 0 {
                result.push(filtered);
            }
        } else {
            result.push(batch);
        }
    }
    Ok(result)
}
2838
/// Applies a rule's post-fixpoint operator chain to its converged fact set,
/// in order: PRIORITY -> FOLD -> HAVING -> BEST BY. Each stage wraps the
/// previous one as an in-memory execution plan; the final plan is collected
/// under `task_ctx`.
pub(crate) async fn apply_post_fixpoint_chain(
    facts: Vec<RecordBatch>,
    rule: &FixpointRulePlan,
    task_ctx: &Arc<TaskContext>,
    strict_probability_domain: bool,
    probability_epsilon: f64,
) -> DFResult<Vec<RecordBatch>> {
    // Fast path: the rule has no post-fixpoint operators.
    if !rule.has_fold && !rule.has_best_by && !rule.has_priority && rule.having.is_empty() {
        return Ok(facts);
    }

    // Prefer the schema of an actual non-empty batch (it may carry extra
    // working columns); fall back to the rule's declared yield schema.
    let schema = facts
        .iter()
        .find(|b| b.num_rows() > 0)
        .map(|b| b.schema())
        .unwrap_or_else(|| Arc::clone(&rule.yield_schema));
    let input: Arc<dyn ExecutionPlan> = Arc::new(InMemoryExec::new(facts, schema.clone()));

    // Remap key-column indices from the yield schema to the batch schema by
    // field name; keys absent from the batch schema are silently skipped.
    let key_column_indices: Vec<usize> = rule
        .key_column_indices
        .iter()
        .filter_map(|&i| {
            let name = rule.yield_schema.field(i).name();
            schema.index_of(name).ok()
        })
        .collect();

    // Stage 1: PRIORITY — requires the synthetic "__priority" column.
    let current: Arc<dyn ExecutionPlan> = if rule.has_priority {
        let priority_schema = input.schema();
        let priority_idx = priority_schema.index_of("__priority").map_err(|_| {
            datafusion::common::DataFusionError::Internal(
                "PRIORITY rule missing __priority column".to_string(),
            )
        })?;
        Arc::new(PriorityExec::new(
            input,
            key_column_indices.clone(),
            priority_idx,
        ))
    } else {
        input
    };

    // Stage 2: FOLD (per-key aggregation).
    let current: Arc<dyn ExecutionPlan> = if rule.has_fold && !rule.fold_bindings.is_empty() {
        Arc::new(FoldExec::new(
            current,
            key_column_indices.clone(),
            rule.fold_bindings.clone(),
            strict_probability_domain,
            probability_epsilon,
        ))
    } else {
        current
    };

    // Stage 3: HAVING — materializes the upstream plan, filters the batches,
    // and re-wraps them; an empty result short-circuits the chain.
    let current: Arc<dyn ExecutionPlan> = if !rule.having.is_empty() {
        let batches = collect_all_partitions(&current, Arc::clone(task_ctx)).await?;
        let filtered = apply_having_filter(batches, &rule.having, &current.schema())?;
        if filtered.is_empty() {
            return Ok(filtered);
        }
        Arc::new(InMemoryExec::new(filtered, Arc::clone(&current.schema())))
    } else {
        current
    };

    // Stage 4: BEST BY (per-key winner selection).
    let current: Arc<dyn ExecutionPlan> = if rule.has_best_by && !rule.best_by_criteria.is_empty() {
        Arc::new(BestByExec::new(
            current,
            key_column_indices.clone(),
            rule.best_by_criteria.clone(),
            rule.deterministic,
        ))
    } else {
        current
    };

    collect_all_partitions(&current, Arc::clone(task_ctx)).await
}
2933
/// Physical operator that runs fixpoint (recursive rule) evaluation and
/// streams the final derived facts.
pub struct FixpointExec {
    // Rule plans evaluated repeatedly until no rule derives new facts.
    rules: Vec<FixpointRulePlan>,
    // Hard cap on iterations of the fixpoint loop.
    max_iterations: usize,
    // Wall-clock budget for the whole loop.
    timeout: Duration,
    graph_ctx: Arc<GraphExecutionContext>,
    session_ctx: Arc<RwLock<datafusion::prelude::SessionContext>>,
    storage: Arc<StorageManager>,
    schema_info: Arc<UniSchema>,
    params: HashMap<String, Value>,
    // Shared buffers backing derived-table scans inside rule bodies.
    derived_scan_registry: Arc<DerivedScanRegistry>,
    output_schema: SchemaRef,
    properties: PlanProperties,
    metrics: ExecutionPlanMetricsSet,
    // Byte budget for accumulated derived facts.
    max_derived_bytes: usize,
    // Optional provenance recording (EXPLAIN-style proofs).
    derivation_tracker: Option<Arc<ProvenanceStore>>,
    // Out-slot: per-rule iteration counts, filled during execution.
    iteration_counts: Arc<StdRwLock<HashMap<String, usize>>>,
    strict_probability_domain: bool,
    probability_epsilon: f64,
    exact_probability: bool,
    max_bdd_variables: usize,
    // Out-slot: runtime warnings surfaced back to the caller.
    warnings_slot: Arc<StdRwLock<Vec<RuntimeWarning>>>,
    // Out-slot: per-rule lists of approximately-computed items.
    approximate_slot: Arc<StdRwLock<HashMap<String, Vec<String>>>>,
    top_k_proofs: usize,
    // Shared flag — presumably set when the time budget is exceeded so
    // cooperating operators can stop early; confirm in run_fixpoint_loop.
    timeout_flag: Arc<std::sync::atomic::AtomicBool>,
}
2973
2974impl fmt::Debug for FixpointExec {
2975 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
2976 f.debug_struct("FixpointExec")
2977 .field("rules_count", &self.rules.len())
2978 .field("max_iterations", &self.max_iterations)
2979 .field("timeout", &self.timeout)
2980 .field("output_schema", &self.output_schema)
2981 .field("max_derived_bytes", &self.max_derived_bytes)
2982 .finish_non_exhaustive()
2983 }
2984}
2985
impl FixpointExec {
    /// Builds a fixpoint operator over `rules`.
    ///
    /// The shared slots (`iteration_counts`, `warnings_slot`,
    /// `approximate_slot`, `timeout_flag`) are written during execution and
    /// read back by the caller afterwards. Plan properties are derived from
    /// `output_schema`.
    #[expect(
        clippy::too_many_arguments,
        reason = "FixpointExec configuration needs all context"
    )]
    pub fn new(
        rules: Vec<FixpointRulePlan>,
        max_iterations: usize,
        timeout: Duration,
        graph_ctx: Arc<GraphExecutionContext>,
        session_ctx: Arc<RwLock<datafusion::prelude::SessionContext>>,
        storage: Arc<StorageManager>,
        schema_info: Arc<UniSchema>,
        params: HashMap<String, Value>,
        derived_scan_registry: Arc<DerivedScanRegistry>,
        output_schema: SchemaRef,
        max_derived_bytes: usize,
        derivation_tracker: Option<Arc<ProvenanceStore>>,
        iteration_counts: Arc<StdRwLock<HashMap<String, usize>>>,
        strict_probability_domain: bool,
        probability_epsilon: f64,
        exact_probability: bool,
        max_bdd_variables: usize,
        warnings_slot: Arc<StdRwLock<Vec<RuntimeWarning>>>,
        approximate_slot: Arc<StdRwLock<HashMap<String, Vec<String>>>>,
        top_k_proofs: usize,
        timeout_flag: Arc<std::sync::atomic::AtomicBool>,
    ) -> Self {
        let properties = compute_plan_properties(Arc::clone(&output_schema));
        Self {
            rules,
            max_iterations,
            timeout,
            graph_ctx,
            session_ctx,
            storage,
            schema_info,
            params,
            derived_scan_registry,
            output_schema,
            properties,
            metrics: ExecutionPlanMetricsSet::new(),
            max_derived_bytes,
            derivation_tracker,
            iteration_counts,
            strict_probability_domain,
            probability_epsilon,
            exact_probability,
            max_bdd_variables,
            warnings_slot,
            approximate_slot,
            top_k_proofs,
            timeout_flag,
        }
    }

    /// Handle to the shared per-rule iteration counters (populated while
    /// the fixpoint runs).
    pub fn iteration_counts(&self) -> Arc<StdRwLock<HashMap<String, usize>>> {
        Arc::clone(&self.iteration_counts)
    }
}
3048
3049impl DisplayAs for FixpointExec {
3050 fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter<'_>) -> fmt::Result {
3051 write!(
3052 f,
3053 "FixpointExec: rules=[{}], max_iter={}, timeout={:?}",
3054 self.rules
3055 .iter()
3056 .map(|r| r.name.as_str())
3057 .collect::<Vec<_>>()
3058 .join(", "),
3059 self.max_iterations,
3060 self.timeout,
3061 )
3062 }
3063}
3064
impl ExecutionPlan for FixpointExec {
    fn name(&self) -> &str {
        "FixpointExec"
    }

    fn as_any(&self) -> &dyn Any {
        self
    }

    fn schema(&self) -> SchemaRef {
        Arc::clone(&self.output_schema)
    }

    fn properties(&self) -> &PlanProperties {
        &self.properties
    }

    // Leaf node: rule bodies are executed internally, not as child plans.
    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
        vec![]
    }

    fn with_new_children(
        self: Arc<Self>,
        children: Vec<Arc<dyn ExecutionPlan>>,
    ) -> DFResult<Arc<dyn ExecutionPlan>> {
        if !children.is_empty() {
            return Err(datafusion::error::DataFusionError::Plan(
                "FixpointExec has no children".to_string(),
            ));
        }
        Ok(self)
    }

    /// Spawns the fixpoint loop as one boxed future and returns a stream
    /// that emits the resulting batches after the loop converges.
    fn execute(
        &self,
        partition: usize,
        _context: Arc<TaskContext>,
    ) -> DFResult<SendableRecordBatchStream> {
        let metrics = BaselineMetrics::new(&self.metrics, partition);

        // Deep-copy the rule plans so the future owns them independently of
        // `self`. NOTE(review): rebuilt field-by-field, presumably because
        // FixpointRulePlan does not derive Clone — confirm.
        let rules = self
            .rules
            .iter()
            .map(|r| {
                FixpointRulePlan {
                    name: r.name.clone(),
                    clauses: r
                        .clauses
                        .iter()
                        .map(|c| FixpointClausePlan {
                            body_logical: c.body_logical.clone(),
                            is_ref_bindings: c.is_ref_bindings.clone(),
                            priority: c.priority,
                            along_bindings: c.along_bindings.clone(),
                        })
                        .collect(),
                    yield_schema: Arc::clone(&r.yield_schema),
                    key_column_indices: r.key_column_indices.clone(),
                    priority: r.priority,
                    has_fold: r.has_fold,
                    fold_bindings: r.fold_bindings.clone(),
                    having: r.having.clone(),
                    has_best_by: r.has_best_by,
                    best_by_criteria: r.best_by_criteria.clone(),
                    has_priority: r.has_priority,
                    deterministic: r.deterministic,
                    prob_column_name: r.prob_column_name.clone(),
                }
            })
            .collect();

        // Clone / Arc-bump all remaining context so the future is 'static.
        let max_iterations = self.max_iterations;
        let timeout = self.timeout;
        let graph_ctx = Arc::clone(&self.graph_ctx);
        let session_ctx = Arc::clone(&self.session_ctx);
        let storage = Arc::clone(&self.storage);
        let schema_info = Arc::clone(&self.schema_info);
        let params = self.params.clone();
        let registry = Arc::clone(&self.derived_scan_registry);
        let output_schema = Arc::clone(&self.output_schema);
        let max_derived_bytes = self.max_derived_bytes;
        let derivation_tracker = self.derivation_tracker.clone();
        let iteration_counts = Arc::clone(&self.iteration_counts);
        let strict_probability_domain = self.strict_probability_domain;
        let probability_epsilon = self.probability_epsilon;
        let exact_probability = self.exact_probability;
        let max_bdd_variables = self.max_bdd_variables;
        let warnings_slot = Arc::clone(&self.warnings_slot);
        let approximate_slot = Arc::clone(&self.approximate_slot);
        let top_k_proofs = self.top_k_proofs;
        let timeout_flag = Arc::clone(&self.timeout_flag);

        let fut = async move {
            run_fixpoint_loop(
                rules,
                max_iterations,
                timeout,
                graph_ctx,
                session_ctx,
                storage,
                schema_info,
                params,
                registry,
                output_schema,
                max_derived_bytes,
                derivation_tracker,
                iteration_counts,
                strict_probability_domain,
                probability_epsilon,
                exact_probability,
                max_bdd_variables,
                warnings_slot,
                approximate_slot,
                top_k_proofs,
                timeout_flag,
            )
            .await
        };

        // The stream polls the future to completion, then replays batches.
        Ok(Box::pin(FixpointStream {
            state: FixpointStreamState::Running(Box::pin(fut)),
            schema: Arc::clone(&self.output_schema),
            metrics,
        }))
    }

    fn metrics(&self) -> Option<MetricsSet> {
        Some(self.metrics.clone_inner())
    }
}
3200
/// State machine driven by `FixpointStream::poll_next`.
enum FixpointStreamState {
    // The fixpoint loop is still running; poll this future to completion.
    Running(Pin<Box<dyn std::future::Future<Output = DFResult<Vec<RecordBatch>>> + Send>>),
    // Loop finished: replay the batches one at a time (usize = next index).
    Emitting(Vec<RecordBatch>, usize),
    // All batches emitted, or the loop errored: stream is exhausted.
    Done,
}
3213
/// RecordBatch stream that runs the fixpoint loop to completion on first
/// poll, then emits the resulting batches one per poll.
struct FixpointStream {
    // Current phase: running the loop, emitting results, or done.
    state: FixpointStreamState,
    // Schema reported to DataFusion consumers.
    schema: SchemaRef,
    // Records output row counts for EXPLAIN ANALYZE.
    metrics: BaselineMetrics,
}
3219
3220impl Stream for FixpointStream {
3221 type Item = DFResult<RecordBatch>;
3222
3223 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
3224 let this = self.get_mut();
3225 loop {
3226 match &mut this.state {
3227 FixpointStreamState::Running(fut) => match fut.as_mut().poll(cx) {
3228 Poll::Ready(Ok(batches)) => {
3229 if batches.is_empty() {
3230 this.state = FixpointStreamState::Done;
3231 return Poll::Ready(None);
3232 }
3233 this.state = FixpointStreamState::Emitting(batches, 0);
3234 }
3236 Poll::Ready(Err(e)) => {
3237 this.state = FixpointStreamState::Done;
3238 return Poll::Ready(Some(Err(e)));
3239 }
3240 Poll::Pending => return Poll::Pending,
3241 },
3242 FixpointStreamState::Emitting(batches, idx) => {
3243 if *idx >= batches.len() {
3244 this.state = FixpointStreamState::Done;
3245 return Poll::Ready(None);
3246 }
3247 let batch = batches[*idx].clone();
3248 *idx += 1;
3249 this.metrics.record_output(batch.num_rows());
3250 return Poll::Ready(Some(Ok(batch)));
3251 }
3252 FixpointStreamState::Done => return Poll::Ready(None),
3253 }
3254 }
3255 }
3256}
3257
3258impl RecordBatchStream for FixpointStream {
3259 fn schema(&self) -> SchemaRef {
3260 Arc::clone(&self.schema)
3261 }
3262}
3263
3264#[cfg(test)]
3269mod tests {
3270 use super::*;
3271 use arrow_array::{Float64Array, Int64Array, StringArray};
3272 use arrow_schema::{DataType, Field, Schema};
3273
3274 fn test_schema() -> SchemaRef {
3275 Arc::new(Schema::new(vec![
3276 Field::new("name", DataType::Utf8, true),
3277 Field::new("value", DataType::Int64, true),
3278 ]))
3279 }
3280
3281 fn make_batch(names: &[&str], values: &[i64]) -> RecordBatch {
3282 RecordBatch::try_new(
3283 test_schema(),
3284 vec![
3285 Arc::new(StringArray::from(
3286 names.iter().map(|s| Some(*s)).collect::<Vec<_>>(),
3287 )),
3288 Arc::new(Int64Array::from(values.to_vec())),
3289 ],
3290 )
3291 .unwrap()
3292 }
3293
    // Merging into an empty state must add every row and mirror it in delta.
    #[tokio::test]
    async fn test_fixpoint_state_empty_facts_adds_all() {
        let schema = test_schema();
        let mut state = FixpointState::new("test".into(), schema, vec![0], 1_000_000, None, false);

        let batch = make_batch(&["a", "b", "c"], &[1, 2, 3]);
        let changed = state.merge_delta(vec![batch], None).await.unwrap();

        assert!(changed);
        assert_eq!(state.all_facts().len(), 1);
        assert_eq!(state.all_facts()[0].num_rows(), 3);
        assert_eq!(state.all_delta().len(), 1);
        assert_eq!(state.all_delta()[0].num_rows(), 3);
    }
3310
    // Re-merging identical rows must report no change and an empty delta.
    #[tokio::test]
    async fn test_fixpoint_state_exact_duplicates_excluded() {
        let schema = test_schema();
        let mut state = FixpointState::new("test".into(), schema, vec![0], 1_000_000, None, false);

        let batch1 = make_batch(&["a", "b"], &[1, 2]);
        state.merge_delta(vec![batch1], None).await.unwrap();

        let batch2 = make_batch(&["a", "b"], &[1, 2]);
        let changed = state.merge_delta(vec![batch2], None).await.unwrap();
        assert!(!changed);
        assert!(
            state.all_delta().is_empty() || state.all_delta().iter().all(|b| b.num_rows() == 0)
        );
    }
3327
    // Overlapping merge: only the genuinely new row lands in the delta.
    #[tokio::test]
    async fn test_fixpoint_state_partial_overlap() {
        let schema = test_schema();
        let mut state = FixpointState::new("test".into(), schema, vec![0], 1_000_000, None, false);

        let batch1 = make_batch(&["a", "b"], &[1, 2]);
        state.merge_delta(vec![batch1], None).await.unwrap();

        let batch2 = make_batch(&["a", "c"], &[1, 3]);
        let changed = state.merge_delta(vec![batch2], None).await.unwrap();
        assert!(changed);

        // Only ("c", 3) is new.
        let delta_rows: usize = state.all_delta().iter().map(|b| b.num_rows()).sum();
        assert_eq!(delta_rows, 1);

        // Facts accumulate to the 3 distinct rows.
        let total_rows: usize = state.all_facts().iter().map(|b| b.num_rows()).sum();
        assert_eq!(total_rows, 3);
    }
3349
    // An empty merge after data settles must mark the state converged.
    #[tokio::test]
    async fn test_fixpoint_state_convergence() {
        let schema = test_schema();
        let mut state = FixpointState::new("test".into(), schema, vec![0], 1_000_000, None, false);

        let batch = make_batch(&["a"], &[1]);
        state.merge_delta(vec![batch], None).await.unwrap();

        let changed = state.merge_delta(vec![], None).await.unwrap();
        assert!(!changed);
        assert!(state.is_converged());
    }
3363
    // RowDedupState must remember rows across compute_delta calls.
    #[test]
    fn test_row_dedup_persistent_across_calls() {
        let schema = test_schema();
        let mut rd = RowDedupState::try_new(&schema).expect("schema should be supported");

        let batch1 = make_batch(&["a", "b"], &[1, 2]);
        let delta1 = rd.compute_delta(&[batch1], &schema).unwrap();
        let rows1: usize = delta1.iter().map(|b| b.num_rows()).sum();
        assert_eq!(rows1, 2);

        // Exact repeat: nothing new.
        let batch2 = make_batch(&["a", "b"], &[1, 2]);
        let delta2 = rd.compute_delta(&[batch2], &schema).unwrap();
        let rows2: usize = delta2.iter().map(|b| b.num_rows()).sum();
        assert_eq!(rows2, 0);

        // Partial overlap: only ("c", 3) survives.
        let batch3 = make_batch(&["a", "c"], &[1, 3]);
        let delta3 = rd.compute_delta(&[batch3], &schema).unwrap();
        let rows3: usize = delta3.iter().map(|b| b.num_rows()).sum();
        assert_eq!(rows3, 1);
    }
3391
    // NULLs must compare equal to each other for dedup purposes, while rows
    // differing in a non-null column remain distinct.
    #[test]
    fn test_row_dedup_null_handling() {
        use arrow_array::StringArray;
        use arrow_schema::{DataType, Field, Schema};

        let schema: SchemaRef = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Utf8, true),
            Field::new("b", DataType::Int64, true),
        ]));
        let mut rd = RowDedupState::try_new(&schema).expect("schema should be supported");

        let batch_nulls = RecordBatch::try_new(
            Arc::clone(&schema),
            vec![
                Arc::new(StringArray::from(vec![None::<&str>, None::<&str>])),
                Arc::new(arrow_array::Int64Array::from(vec![1i64, 1i64])),
            ],
        )
        .unwrap();
        let delta = rd.compute_delta(&[batch_nulls], &schema).unwrap();
        let rows: usize = delta.iter().map(|b| b.num_rows()).sum();
        assert_eq!(rows, 1, "two identical NULL rows should be deduped to one");

        let batch_diff = RecordBatch::try_new(
            Arc::clone(&schema),
            vec![
                Arc::new(StringArray::from(vec![None::<&str>])),
                Arc::new(arrow_array::Int64Array::from(vec![2i64])),
            ],
        )
        .unwrap();
        let delta2 = rd.compute_delta(&[batch_diff], &schema).unwrap();
        let rows2: usize = delta2.iter().map(|b| b.num_rows()).sum();
        assert_eq!(rows2, 1, "(NULL, 2) is distinct from (NULL, 1)");
    }
3429
    // Duplicates inside a single candidate batch must also be collapsed.
    #[test]
    fn test_row_dedup_within_candidate_dedup() {
        let schema = test_schema();
        let mut rd = RowDedupState::try_new(&schema).expect("schema should be supported");

        let batch = make_batch(&["a", "a", "b"], &[1, 1, 2]);
        let delta = rd.compute_delta(&[batch], &schema).unwrap();
        let rows: usize = delta.iter().map(|b| b.num_rows()).sum();
        assert_eq!(rows, 2, "within-batch dup should be collapsed: a:1, b:2");
    }
3442
    // Float rounding must make near-duplicate values identical so that
    // downstream dedup treats them as the same row.
    #[test]
    fn test_round_float_columns_near_duplicates() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("name", DataType::Utf8, true),
            Field::new("dist", DataType::Float64, true),
        ]));
        let batch = RecordBatch::try_new(
            Arc::clone(&schema),
            vec![
                Arc::new(StringArray::from(vec![Some("a"), Some("a")])),
                Arc::new(Float64Array::from(vec![1.0000000000001, 1.0000000000002])),
            ],
        )
        .unwrap();

        let rounded = round_float_columns(&[batch]);
        assert_eq!(rounded.len(), 1);
        let col = rounded[0]
            .column(1)
            .as_any()
            .downcast_ref::<Float64Array>()
            .unwrap();
        assert_eq!(col.value(0), col.value(1));
    }
3470
    // Data written through the registry must be visible via the shared
    // buffer handle held by the entry.
    #[test]
    fn test_registry_write_read_round_trip() {
        let schema = test_schema();
        let data = Arc::new(RwLock::new(Vec::new()));
        let mut reg = DerivedScanRegistry::new();
        reg.add(DerivedScanEntry {
            scan_index: 0,
            rule_name: "reachable".into(),
            is_self_ref: true,
            data: Arc::clone(&data),
            schema: Arc::clone(&schema),
        });

        let batch = make_batch(&["x"], &[42]);
        reg.write_data(0, vec![batch.clone()]);

        let entry = reg.get(0).unwrap();
        let guard = entry.data.read();
        assert_eq!(guard.len(), 1);
        assert_eq!(guard[0].num_rows(), 1);
    }
3494
    // entries_for_rule must filter by rule name, not by scan index.
    #[test]
    fn test_registry_entries_for_rule() {
        let schema = test_schema();
        let mut reg = DerivedScanRegistry::new();
        reg.add(DerivedScanEntry {
            scan_index: 0,
            rule_name: "r1".into(),
            is_self_ref: true,
            data: Arc::new(RwLock::new(Vec::new())),
            schema: Arc::clone(&schema),
        });
        reg.add(DerivedScanEntry {
            scan_index: 1,
            rule_name: "r2".into(),
            is_self_ref: false,
            data: Arc::new(RwLock::new(Vec::new())),
            schema: Arc::clone(&schema),
        });
        reg.add(DerivedScanEntry {
            scan_index: 2,
            rule_name: "r1".into(),
            is_self_ref: false,
            data: Arc::new(RwLock::new(Vec::new())),
            schema: Arc::clone(&schema),
        });

        assert_eq!(reg.entries_for_rule("r1").len(), 2);
        assert_eq!(reg.entries_for_rule("r2").len(), 1);
        assert_eq!(reg.entries_for_rule("r3").len(), 0);
    }
3525
3526 #[test]
3529 fn test_monotonic_agg_update_and_stability() {
3530 use crate::query::df_graph::locy_fold::FoldAggKind;
3531
3532 let bindings = vec![MonotonicFoldBinding {
3533 fold_name: "total".into(),
3534 kind: FoldAggKind::Sum,
3535 input_col_index: 1,
3536 input_col_name: None,
3537 }];
3538 let mut agg = MonotonicAggState::new(bindings);
3539
3540 let batch = make_batch(&["a"], &[10]);
3542 agg.snapshot();
3543 let changed = agg.update(&[0], &[batch], false).unwrap();
3544 assert!(changed);
3545 assert!(!agg.is_stable()); agg.snapshot();
3549 let changed = agg.update(&[0], &[], false).unwrap();
3550 assert!(!changed);
3551 assert!(agg.is_stable());
3552 }
3553
    // A 1-byte budget must make any merge fail with a memory-limit error.
    #[tokio::test]
    async fn test_memory_limit_exceeded() {
        let schema = test_schema();
        let mut state = FixpointState::new("test".into(), schema, vec![0], 1, None, false);

        let batch = make_batch(&["a", "b", "c"], &[1, 2, 3]);
        let result = state.merge_delta(vec![batch], None).await;
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(err.contains("memory limit"), "Error was: {}", err);
    }
3568
    // A stream started in the Emitting state must yield each batch once.
    #[tokio::test]
    async fn test_fixpoint_stream_emitting() {
        use futures::StreamExt;

        let schema = test_schema();
        let batch1 = make_batch(&["a"], &[1]);
        let batch2 = make_batch(&["b"], &[2]);

        let metrics = ExecutionPlanMetricsSet::new();
        let baseline = BaselineMetrics::new(&metrics, 0);

        let mut stream = FixpointStream {
            state: FixpointStreamState::Emitting(vec![batch1, batch2], 0),
            schema,
            metrics: baseline,
        };

        let stream = Pin::new(&mut stream);
        let batches: Vec<RecordBatch> = stream.filter_map(|r| async { r.ok() }).collect().await;

        assert_eq!(batches.len(), 2);
        assert_eq!(batches[0].num_rows(), 1);
        assert_eq!(batches[1].num_rows(), 1);
    }
3595
3596 fn make_f64_batch(names: &[&str], values: &[f64]) -> RecordBatch {
3599 let schema = Arc::new(Schema::new(vec![
3600 Field::new("name", DataType::Utf8, true),
3601 Field::new("value", DataType::Float64, true),
3602 ]));
3603 RecordBatch::try_new(
3604 schema,
3605 vec![
3606 Arc::new(StringArray::from(
3607 names.iter().map(|s| Some(*s)).collect::<Vec<_>>(),
3608 )),
3609 Arc::new(Float64Array::from(values.to_vec())),
3610 ],
3611 )
3612 .unwrap()
3613 }
3614
3615 fn make_nor_binding() -> Vec<MonotonicFoldBinding> {
3616 use crate::query::df_graph::locy_fold::FoldAggKind;
3617 vec![MonotonicFoldBinding {
3618 fold_name: "prob".into(),
3619 kind: FoldAggKind::Nor,
3620 input_col_index: 1,
3621 input_col_name: None,
3622 }]
3623 }
3624
3625 fn make_prod_binding() -> Vec<MonotonicFoldBinding> {
3626 use crate::query::df_graph::locy_fold::FoldAggKind;
3627 vec![MonotonicFoldBinding {
3628 fold_name: "prob".into(),
3629 kind: FoldAggKind::Prod,
3630 input_col_index: 1,
3631 input_col_name: None,
3632 }]
3633 }
3634
3635 fn acc_key(name: &str) -> (Vec<ScalarKey>, String) {
3636 (vec![ScalarKey::Utf8(name.to_string())], "prob".to_string())
3637 }
3638
    // First noisy-or contribution p yields accumulator p itself.
    #[test]
    fn test_monotonic_nor_first_update() {
        let mut agg = MonotonicAggState::new(make_nor_binding());
        let batch = make_f64_batch(&["a"], &[0.3]);
        let changed = agg.update(&[0], &[batch], false).unwrap();
        assert!(changed);
        let val = agg.get_accumulator(&acc_key("a")).unwrap();
        assert!((val - 0.3).abs() < 1e-10, "expected 0.3, got {}", val);
    }
3648
    // Noisy-or combine: 1 - (1-0.3)(1-0.5) = 0.65.
    #[test]
    fn test_monotonic_nor_two_updates() {
        let mut agg = MonotonicAggState::new(make_nor_binding());
        let batch1 = make_f64_batch(&["a"], &[0.3]);
        agg.update(&[0], &[batch1], false).unwrap();
        let batch2 = make_f64_batch(&["a"], &[0.5]);
        agg.update(&[0], &[batch2], false).unwrap();
        let val = agg.get_accumulator(&acc_key("a")).unwrap();
        assert!((val - 0.65).abs() < 1e-10, "expected 0.65, got {}", val);
    }
3660
    // First product contribution p yields accumulator p itself.
    #[test]
    fn test_monotonic_prod_first_update() {
        let mut agg = MonotonicAggState::new(make_prod_binding());
        let batch = make_f64_batch(&["a"], &[0.6]);
        let changed = agg.update(&[0], &[batch], false).unwrap();
        assert!(changed);
        let val = agg.get_accumulator(&acc_key("a")).unwrap();
        assert!((val - 0.6).abs() < 1e-10, "expected 0.6, got {}", val);
    }
3670
    // Product combine: 0.6 * 0.8 = 0.48.
    #[test]
    fn test_monotonic_prod_two_updates() {
        let mut agg = MonotonicAggState::new(make_prod_binding());
        let batch1 = make_f64_batch(&["a"], &[0.6]);
        agg.update(&[0], &[batch1], false).unwrap();
        let batch2 = make_f64_batch(&["a"], &[0.8]);
        agg.update(&[0], &[batch2], false).unwrap();
        let val = agg.get_accumulator(&acc_key("a")).unwrap();
        assert!((val - 0.48).abs() < 1e-10, "expected 0.48, got {}", val);
    }
3682
    // An empty delta after a noisy-or update must be stable.
    #[test]
    fn test_monotonic_nor_stability() {
        let mut agg = MonotonicAggState::new(make_nor_binding());
        let batch = make_f64_batch(&["a"], &[0.3]);
        agg.update(&[0], &[batch], false).unwrap();
        agg.snapshot();
        let changed = agg.update(&[0], &[], false).unwrap();
        assert!(!changed);
        assert!(agg.is_stable());
    }
3693
    // An empty delta after a product update must be stable.
    #[test]
    fn test_monotonic_prod_stability() {
        let mut agg = MonotonicAggState::new(make_prod_binding());
        let batch = make_f64_batch(&["a"], &[0.6]);
        agg.update(&[0], &[batch], false).unwrap();
        agg.snapshot();
        let changed = agg.update(&[0], &[], false).unwrap();
        assert!(!changed);
        assert!(agg.is_stable());
    }
3704
    // Groups accumulate independently: a -> 0.65, b -> 1-(1-0.5)(1-0.2)=0.6.
    #[test]
    fn test_monotonic_nor_multi_group() {
        let mut agg = MonotonicAggState::new(make_nor_binding());
        let batch1 = make_f64_batch(&["a", "b"], &[0.3, 0.5]);
        agg.update(&[0], &[batch1], false).unwrap();
        let batch2 = make_f64_batch(&["a", "b"], &[0.5, 0.2]);
        agg.update(&[0], &[batch2], false).unwrap();

        let val_a = agg.get_accumulator(&acc_key("a")).unwrap();
        let val_b = agg.get_accumulator(&acc_key("b")).unwrap();
        assert!(
            (val_a - 0.65).abs() < 1e-10,
            "expected a=0.65, got {}",
            val_a
        );
        assert!((val_b - 0.6).abs() < 1e-10, "expected b=0.6, got {}", val_b);
    }
3723
    // Zero absorbs under product: once 0.0, further updates cannot change it.
    #[test]
    fn test_monotonic_prod_zero_absorbing() {
        let mut agg = MonotonicAggState::new(make_prod_binding());
        let batch1 = make_f64_batch(&["a"], &[0.5]);
        agg.update(&[0], &[batch1], false).unwrap();
        let batch2 = make_f64_batch(&["a"], &[0.0]);
        agg.update(&[0], &[batch2], false).unwrap();

        let val = agg.get_accumulator(&acc_key("a")).unwrap();
        assert!((val - 0.0).abs() < 1e-10, "expected 0.0, got {}", val);

        agg.snapshot();
        let batch3 = make_f64_batch(&["a"], &[0.5]);
        let changed = agg.update(&[0], &[batch3], false).unwrap();
        assert!(!changed);
        assert!(agg.is_stable());
    }
3743
    // Non-strict mode clamps out-of-domain inputs (1.5 -> 1.0).
    #[test]
    fn test_monotonic_nor_clamping() {
        let mut agg = MonotonicAggState::new(make_nor_binding());
        let batch = make_f64_batch(&["a"], &[1.5]);
        agg.update(&[0], &[batch], false).unwrap();
        let val = agg.get_accumulator(&acc_key("a")).unwrap();
        assert!((val - 1.0).abs() < 1e-10, "expected 1.0, got {}", val);
    }
3753
    // One absorbs under noisy-or: contributing 1.0 pins the accumulator.
    #[test]
    fn test_monotonic_nor_absorbing() {
        let mut agg = MonotonicAggState::new(make_nor_binding());
        let batch1 = make_f64_batch(&["a"], &[0.3]);
        agg.update(&[0], &[batch1], false).unwrap();
        let batch2 = make_f64_batch(&["a"], &[1.0]);
        agg.update(&[0], &[batch2], false).unwrap();
        let val = agg.get_accumulator(&acc_key("a")).unwrap();
        assert!((val - 1.0).abs() < 1e-10, "expected 1.0, got {}", val);
    }
3765
    // Strict mode must reject an out-of-domain noisy-or input instead of
    // clamping it.
    #[test]
    fn test_monotonic_agg_strict_nor_rejects() {
        let mut agg = MonotonicAggState::new(make_nor_binding());
        let batch = make_f64_batch(&["a"], &[1.5]);
        let result = agg.update(&[0], &[batch], true);
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(
            err.contains("strict_probability_domain"),
            "Expected strict error, got: {}",
            err
        );
    }
3781
    // Strict mode must reject an out-of-domain product input as well.
    #[test]
    fn test_monotonic_agg_strict_prod_rejects() {
        let mut agg = MonotonicAggState::new(make_prod_binding());
        let batch = make_f64_batch(&["a"], &[2.0]);
        let result = agg.update(&[0], &[batch], true);
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(
            err.contains("strict_probability_domain"),
            "Expected strict error, got: {}",
            err
        );
    }
3795
    // Strict mode must still accept values inside [0, 1].
    #[test]
    fn test_monotonic_agg_strict_accepts_valid() {
        let mut agg = MonotonicAggState::new(make_nor_binding());
        let batch = make_f64_batch(&["a"], &[0.5]);
        let result = agg.update(&[0], &[batch], true);
        assert!(result.is_ok());
        let val = agg.get_accumulator(&acc_key("a")).unwrap();
        assert!((val - 0.5).abs() < 1e-10, "expected 0.5, got {}", val);
    }
3805
3806 fn make_vid_prob_batch(vids: &[u64], probs: &[f64]) -> RecordBatch {
3809 use arrow_array::UInt64Array;
3810 let schema = Arc::new(Schema::new(vec![
3811 Field::new("vid", DataType::UInt64, true),
3812 Field::new("prob", DataType::Float64, true),
3813 ]));
3814 RecordBatch::try_new(
3815 schema,
3816 vec![
3817 Arc::new(UInt64Array::from(vids.to_vec())),
3818 Arc::new(Float64Array::from(probs.to_vec())),
3819 ],
3820 )
3821 .unwrap()
3822 }
3823
3824 #[test]
3825 fn test_prob_complement_basic() {
3826 let body = make_vid_prob_batch(&[1, 2], &[0.9, 0.8]);
3828 let neg = make_vid_prob_batch(&[1], &[0.7]);
3829 let join_cols = vec![("vid".to_string(), "vid".to_string())];
3830 let result = apply_prob_complement_composite(
3831 vec![body],
3832 &[neg],
3833 &join_cols,
3834 "prob",
3835 "__complement_0",
3836 )
3837 .unwrap();
3838 assert_eq!(result.len(), 1);
3839 let batch = &result[0];
3840 let complement = batch
3841 .column_by_name("__complement_0")
3842 .unwrap()
3843 .as_any()
3844 .downcast_ref::<Float64Array>()
3845 .unwrap();
3846 assert!(
3848 (complement.value(0) - 0.3).abs() < 1e-10,
3849 "expected 0.3, got {}",
3850 complement.value(0)
3851 );
3852 assert!(
3854 (complement.value(1) - 1.0).abs() < 1e-10,
3855 "expected 1.0, got {}",
3856 complement.value(1)
3857 );
3858 }
3859
3860 #[test]
3861 fn test_prob_complement_noisy_or_duplicates() {
3862 let body = make_vid_prob_batch(&[1], &[0.9]);
3866 let neg = make_vid_prob_batch(&[1, 1], &[0.3, 0.5]);
3867 let join_cols = vec![("vid".to_string(), "vid".to_string())];
3868 let result = apply_prob_complement_composite(
3869 vec![body],
3870 &[neg],
3871 &join_cols,
3872 "prob",
3873 "__complement_0",
3874 )
3875 .unwrap();
3876 let batch = &result[0];
3877 let complement = batch
3878 .column_by_name("__complement_0")
3879 .unwrap()
3880 .as_any()
3881 .downcast_ref::<Float64Array>()
3882 .unwrap();
3883 assert!(
3884 (complement.value(0) - 0.35).abs() < 1e-10,
3885 "expected 0.35, got {}",
3886 complement.value(0)
3887 );
3888 }
3889
3890 #[test]
3891 fn test_prob_complement_empty_neg() {
3892 let body = make_vid_prob_batch(&[1, 2], &[0.5, 0.6]);
3894 let join_cols = vec![("vid".to_string(), "vid".to_string())];
3895 let result =
3896 apply_prob_complement_composite(vec![body], &[], &join_cols, "prob", "__complement_0")
3897 .unwrap();
3898 let batch = &result[0];
3899 let complement = batch
3900 .column_by_name("__complement_0")
3901 .unwrap()
3902 .as_any()
3903 .downcast_ref::<Float64Array>()
3904 .unwrap();
3905 for i in 0..2 {
3906 assert!(
3907 (complement.value(i) - 1.0).abs() < 1e-10,
3908 "row {}: expected 1.0, got {}",
3909 i,
3910 complement.value(i)
3911 );
3912 }
3913 }
3914
3915 #[test]
3916 fn test_anti_join_basic() {
3917 use arrow_array::UInt64Array;
3919 let body = make_vid_prob_batch(&[1, 2, 3], &[0.5, 0.6, 0.7]);
3920 let neg = make_vid_prob_batch(&[2], &[0.0]);
3921 let join_cols = vec![("vid".to_string(), "vid".to_string())];
3922 let result = apply_anti_join_composite(vec![body], &[neg], &join_cols).unwrap();
3923 assert_eq!(result.len(), 1);
3924 let batch = &result[0];
3925 assert_eq!(batch.num_rows(), 2);
3926 let vids = batch
3927 .column_by_name("vid")
3928 .unwrap()
3929 .as_any()
3930 .downcast_ref::<UInt64Array>()
3931 .unwrap();
3932 assert_eq!(vids.value(0), 1);
3933 assert_eq!(vids.value(1), 3);
3934 }
3935
3936 #[test]
3937 fn test_anti_join_empty_neg() {
3938 let body = make_vid_prob_batch(&[1, 2, 3], &[0.5, 0.6, 0.7]);
3940 let join_cols = vec![("vid".to_string(), "vid".to_string())];
3941 let result = apply_anti_join_composite(vec![body], &[], &join_cols).unwrap();
3942 assert_eq!(result.len(), 1);
3943 assert_eq!(result[0].num_rows(), 3);
3944 }
3945
3946 #[test]
3947 fn test_anti_join_all_excluded() {
3948 let body = make_vid_prob_batch(&[1, 2], &[0.5, 0.6]);
3950 let neg = make_vid_prob_batch(&[1, 2], &[0.0, 0.0]);
3951 let join_cols = vec![("vid".to_string(), "vid".to_string())];
3952 let result = apply_anti_join_composite(vec![body], &[neg], &join_cols).unwrap();
3953 let total: usize = result.iter().map(|b| b.num_rows()).sum();
3954 assert_eq!(total, 0);
3955 }
3956
3957 #[test]
3958 fn test_multiply_prob_single_complement() {
3959 let body = make_vid_prob_batch(&[1], &[0.8]);
3961 let complement_arr = Float64Array::from(vec![0.5]);
3963 let mut cols: Vec<arrow_array::ArrayRef> = body.columns().to_vec();
3964 cols.push(Arc::new(complement_arr));
3965 let mut fields: Vec<Arc<Field>> = body.schema().fields().iter().cloned().collect();
3966 fields.push(Arc::new(Field::new(
3967 "__complement_0",
3968 DataType::Float64,
3969 true,
3970 )));
3971 let schema = Arc::new(Schema::new(fields));
3972 let batch = RecordBatch::try_new(schema, cols).unwrap();
3973
3974 let result =
3975 multiply_prob_factors(vec![batch], Some("prob"), &["__complement_0".to_string()])
3976 .unwrap();
3977 assert_eq!(result.len(), 1);
3978 let out = &result[0];
3979 assert!(out.column_by_name("__complement_0").is_none());
3981 let prob = out
3982 .column_by_name("prob")
3983 .unwrap()
3984 .as_any()
3985 .downcast_ref::<Float64Array>()
3986 .unwrap();
3987 assert!(
3988 (prob.value(0) - 0.4).abs() < 1e-10,
3989 "expected 0.4, got {}",
3990 prob.value(0)
3991 );
3992 }
3993
3994 #[test]
3995 fn test_multiply_prob_multiple_complements() {
3996 let body = make_vid_prob_batch(&[1], &[0.8]);
3998 let c1 = Float64Array::from(vec![0.5]);
3999 let c2 = Float64Array::from(vec![0.6]);
4000 let mut cols: Vec<arrow_array::ArrayRef> = body.columns().to_vec();
4001 cols.push(Arc::new(c1));
4002 cols.push(Arc::new(c2));
4003 let mut fields: Vec<Arc<Field>> = body.schema().fields().iter().cloned().collect();
4004 fields.push(Arc::new(Field::new("__c1", DataType::Float64, true)));
4005 fields.push(Arc::new(Field::new("__c2", DataType::Float64, true)));
4006 let schema = Arc::new(Schema::new(fields));
4007 let batch = RecordBatch::try_new(schema, cols).unwrap();
4008
4009 let result = multiply_prob_factors(
4010 vec![batch],
4011 Some("prob"),
4012 &["__c1".to_string(), "__c2".to_string()],
4013 )
4014 .unwrap();
4015 let out = &result[0];
4016 assert!(out.column_by_name("__c1").is_none());
4017 assert!(out.column_by_name("__c2").is_none());
4018 let prob = out
4019 .column_by_name("prob")
4020 .unwrap()
4021 .as_any()
4022 .downcast_ref::<Float64Array>()
4023 .unwrap();
4024 assert!(
4025 (prob.value(0) - 0.24).abs() < 1e-10,
4026 "expected 0.24, got {}",
4027 prob.value(0)
4028 );
4029 }
4030
4031 #[test]
4032 fn test_multiply_prob_no_prob_column() {
4033 use arrow_array::UInt64Array;
4035 let schema = Arc::new(Schema::new(vec![
4036 Field::new("vid", DataType::UInt64, true),
4037 Field::new("__c1", DataType::Float64, true),
4038 ]));
4039 let batch = RecordBatch::try_new(
4040 schema,
4041 vec![
4042 Arc::new(UInt64Array::from(vec![1u64])),
4043 Arc::new(Float64Array::from(vec![0.7])),
4044 ],
4045 )
4046 .unwrap();
4047
4048 let result = multiply_prob_factors(vec![batch], None, &["__c1".to_string()]).unwrap();
4049 let out = &result[0];
4050 assert!(out.column_by_name("__c1").is_none());
4052 assert_eq!(out.num_columns(), 1);
4054 }
4055}