1use crate::query::df_graph::GraphExecutionContext;
10use crate::query::df_graph::common::{
11 ScalarKey, arrow_err, collect_all_partitions, compute_plan_properties, execute_subplan,
12 extract_scalar_key,
13};
14use crate::query::df_graph::locy_best_by::{BestByExec, SortCriterion};
15use crate::query::df_graph::locy_errors::LocyRuntimeError;
16use crate::query::df_graph::locy_explain::{
17 ProofTerm, ProvenanceAnnotation, ProvenanceStore, compute_proof_probability,
18};
19use crate::query::df_graph::locy_fold::{FoldBinding, FoldExec};
20use crate::query::df_graph::locy_priority::PriorityExec;
21use crate::query::planner::LogicalPlan;
22use arrow_array::RecordBatch;
23use arrow_row::{RowConverter, SortField};
24use arrow_schema::SchemaRef;
25use datafusion::common::JoinType;
26use datafusion::common::Result as DFResult;
27use datafusion::execution::{RecordBatchStream, SendableRecordBatchStream, TaskContext};
28use datafusion::physical_plan::joins::{HashJoinExec, PartitionMode};
29use datafusion::physical_plan::memory::MemoryStream;
30use datafusion::physical_plan::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet};
31use datafusion::physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties};
32use futures::Stream;
33use parking_lot::RwLock;
34use std::any::Any;
35use std::collections::{HashMap, HashSet};
36use std::fmt;
37use std::pin::Pin;
38use std::sync::{Arc, RwLock as StdRwLock};
39use std::task::{Context, Poll};
40use std::time::{Duration, Instant};
41use uni_common::Value;
42use uni_common::core::schema::Schema as UniSchema;
43use uni_cypher::ast::Expr;
44use uni_locy::RuntimeWarning;
45use uni_store::storage::manager::StorageManager;
46
/// One registered derived-table scan, backed by a shared lock-protected
/// buffer that the fixpoint loop overwrites between iterations.
#[derive(Debug)]
pub struct DerivedScanEntry {
    /// Index used by physical scans to locate this entry in the registry.
    pub scan_index: usize,
    /// Name of the fixpoint rule producing this derived table.
    pub rule_name: String,
    /// True when the scan references the enclosing rule's own output
    /// (recursive reference).
    pub is_self_ref: bool,
    /// Shared buffer holding the rule's current facts; refreshed each
    /// iteration via `DerivedScanRegistry::write_data`.
    pub data: Arc<RwLock<Vec<RecordBatch>>>,
    /// Schema of the derived table's batches.
    pub schema: SchemaRef,
}
69
/// Flat collection of derived-scan entries, looked up by scan index.
#[derive(Debug, Default)]
pub struct DerivedScanRegistry {
    // Linear scan on lookup; the expected entry count is small.
    entries: Vec<DerivedScanEntry>,
}
80
81impl DerivedScanRegistry {
82 pub fn new() -> Self {
84 Self::default()
85 }
86
87 pub fn add(&mut self, entry: DerivedScanEntry) {
89 self.entries.push(entry);
90 }
91
92 pub fn get(&self, scan_index: usize) -> Option<&DerivedScanEntry> {
94 self.entries.iter().find(|e| e.scan_index == scan_index)
95 }
96
97 pub fn write_data(&self, scan_index: usize, batches: Vec<RecordBatch>) {
99 if let Some(entry) = self.get(scan_index) {
100 let mut guard = entry.data.write();
101 *guard = batches;
102 }
103 }
104
105 pub fn entries_for_rule(&self, rule_name: &str) -> Vec<&DerivedScanEntry> {
107 self.entries
108 .iter()
109 .filter(|e| e.rule_name == rule_name)
110 .collect()
111 }
112}
113
/// Binding of one monotonic FOLD aggregate to its input column.
#[derive(Debug, Clone)]
pub struct MonotonicFoldBinding {
    /// Output name of the fold; also part of the accumulator key.
    pub fold_name: String,
    /// Which aggregate to apply (Sum, Count, Min, Max, Nor, Prod, ...).
    pub kind: crate::query::df_graph::locy_fold::FoldAggKind,
    /// Positional fallback index of the input column.
    pub input_col_index: usize,
    /// Preferred input column name; when present it is resolved by name
    /// against each batch's schema before falling back to the index.
    pub input_col_name: Option<String>,
}
127
/// Accumulator state for monotonic aggregates across fixpoint iterations.
///
/// Stability is decided by comparing `accumulators` against the snapshot
/// taken at the previous iteration (see `snapshot` / `is_stable`).
#[derive(Debug)]
pub struct MonotonicAggState {
    /// Current value per (group key, fold name).
    accumulators: HashMap<(Vec<ScalarKey>, String), f64>,
    /// Accumulator values as of the last `snapshot()` call.
    prev_snapshot: HashMap<(Vec<ScalarKey>, String), f64>,
    /// The fold bindings being maintained.
    bindings: Vec<MonotonicFoldBinding>,
}
142
143impl MonotonicAggState {
144 pub fn new(bindings: Vec<MonotonicFoldBinding>) -> Self {
146 Self {
147 accumulators: HashMap::new(),
148 prev_snapshot: HashMap::new(),
149 bindings,
150 }
151 }
152
153 pub fn update(
159 &mut self,
160 key_indices: &[usize],
161 delta_batches: &[RecordBatch],
162 strict: bool,
163 ) -> DFResult<bool> {
164 use crate::query::df_graph::locy_fold::FoldAggKind;
165
166 let mut changed = false;
167 for batch in delta_batches {
168 for row_idx in 0..batch.num_rows() {
169 let group_key = extract_scalar_key(batch, key_indices, row_idx);
170 for binding in &self.bindings {
171 let idx = binding
173 .input_col_name
174 .as_ref()
175 .and_then(|name| batch.schema().index_of(name).ok())
176 .unwrap_or(binding.input_col_index);
177 if idx >= batch.num_columns() {
178 continue;
179 }
180 let col = batch.column(idx);
181 let val = extract_f64(col.as_ref(), row_idx);
182 if let Some(val) = val {
183 let map_key = (group_key.clone(), binding.fold_name.clone());
184 let entry = self
185 .accumulators
186 .entry(map_key)
187 .or_insert(binding.kind.identity().unwrap_or(0.0));
188 let old = *entry;
189 match binding.kind {
190 FoldAggKind::Sum | FoldAggKind::Count => *entry += val,
191 FoldAggKind::Max => {
192 if val > *entry {
193 *entry = val;
194 }
195 }
196 FoldAggKind::Min => {
197 if val < *entry {
198 *entry = val;
199 }
200 }
201 FoldAggKind::Nor => {
202 if strict && !(0.0..=1.0).contains(&val) {
203 return Err(datafusion::error::DataFusionError::Execution(
204 format!(
205 "strict_probability_domain: MNOR input {val} is outside [0, 1]"
206 ),
207 ));
208 }
209 if !strict && !(0.0..=1.0).contains(&val) {
210 tracing::warn!(
211 "MNOR input {val} outside [0,1], clamped to {}",
212 val.clamp(0.0, 1.0)
213 );
214 }
215 let p = val.clamp(0.0, 1.0);
216 *entry = 1.0 - (1.0 - *entry) * (1.0 - p);
217 }
218 FoldAggKind::Prod => {
219 if strict && !(0.0..=1.0).contains(&val) {
220 return Err(datafusion::error::DataFusionError::Execution(
221 format!(
222 "strict_probability_domain: MPROD input {val} is outside [0, 1]"
223 ),
224 ));
225 }
226 if !strict && !(0.0..=1.0).contains(&val) {
227 tracing::warn!(
228 "MPROD input {val} outside [0,1], clamped to {}",
229 val.clamp(0.0, 1.0)
230 );
231 }
232 let p = val.clamp(0.0, 1.0);
233 *entry *= p;
234 }
235 _ => {}
236 }
237 if (*entry - old).abs() > f64::EPSILON {
238 changed = true;
239 }
240 }
241 }
242 }
243 }
244 Ok(changed)
245 }
246
247 pub fn snapshot(&mut self) {
249 self.prev_snapshot = self.accumulators.clone();
250 }
251
252 pub fn is_stable(&self) -> bool {
254 if self.accumulators.len() != self.prev_snapshot.len() {
255 return false;
256 }
257 for (key, val) in &self.accumulators {
258 match self.prev_snapshot.get(key) {
259 Some(prev) if (*val - *prev).abs() <= f64::EPSILON => {}
260 _ => return false,
261 }
262 }
263 true
264 }
265
266 #[cfg(test)]
268 pub(crate) fn get_accumulator(&self, key: &(Vec<ScalarKey>, String)) -> Option<f64> {
269 self.accumulators.get(key).copied()
270 }
271}
272
273fn extract_f64(col: &dyn arrow_array::Array, row_idx: usize) -> Option<f64> {
275 if col.is_null(row_idx) {
276 return None;
277 }
278 if let Some(arr) = col.as_any().downcast_ref::<arrow_array::Float64Array>() {
279 Some(arr.value(row_idx))
280 } else {
281 col.as_any()
282 .downcast_ref::<arrow_array::Int64Array>()
283 .map(|arr| arr.value(row_idx) as f64)
284 }
285}
286
/// Deduplicates rows via the arrow row format: each row is converted to a
/// normalized byte encoding and tracked in `seen`.
struct RowDedupState {
    /// Converts batch columns into the comparable row encoding.
    converter: RowConverter,
    /// Byte encodings of every row observed so far.
    seen: HashSet<Box<[u8]>>,
}
300
impl RowDedupState {
    /// Attempts to build a row-format dedup state for `schema`.
    ///
    /// Returns `None` (with a warning) when `RowConverter` does not support
    /// one of the column types; callers then fall back to the legacy
    /// scalar-key dedup path.
    fn try_new(schema: &SchemaRef) -> Option<Self> {
        let fields: Vec<SortField> = schema
            .fields()
            .iter()
            .map(|f| SortField::new(f.data_type().clone()))
            .collect();
        match RowConverter::new(fields) {
            Ok(converter) => Some(Self {
                converter,
                seen: HashSet::new(),
            }),
            Err(e) => {
                tracing::warn!(
                    "RowDedupState: RowConverter unsupported for schema, falling back to legacy dedup: {}",
                    e
                );
                None
            }
        }
    }

    /// Rebuilds `seen` from `facts`, discarding any previous contents.
    ///
    /// Batches whose columns fail row conversion are skipped silently
    /// (best-effort: such rows may later be re-emitted as "new").
    fn ingest_existing(&mut self, facts: &[RecordBatch], _schema: &SchemaRef) {
        self.seen.clear();
        for batch in facts {
            if batch.num_rows() == 0 {
                continue;
            }
            let arrays: Vec<_> = batch.columns().to_vec();
            if let Ok(rows) = self.converter.convert_columns(&arrays) {
                for row_idx in 0..batch.num_rows() {
                    let row_bytes: Box<[u8]> = rows.row(row_idx).data().into();
                    self.seen.insert(row_bytes);
                }
            }
        }
    }

    /// Filters `candidates` down to rows not yet in `seen`, inserting each
    /// surviving row so duplicates within `candidates` are removed too.
    ///
    /// Batches that end up empty after filtering are dropped from the
    /// output.
    fn compute_delta(
        &mut self,
        candidates: &[RecordBatch],
        schema: &SchemaRef,
    ) -> DFResult<Vec<RecordBatch>> {
        let mut delta_batches = Vec::new();
        for batch in candidates {
            if batch.num_rows() == 0 {
                continue;
            }

            let arrays: Vec<_> = batch.columns().to_vec();
            let rows = self.converter.convert_columns(&arrays).map_err(arrow_err)?;

            // `insert` returns true only for rows not seen before, which
            // serves as both cross-batch and intra-batch dedup.
            let mut keep = Vec::with_capacity(batch.num_rows());
            for row_idx in 0..batch.num_rows() {
                let row_bytes: Box<[u8]> = rows.row(row_idx).data().into();
                keep.push(self.seen.insert(row_bytes));
            }

            let keep_mask = arrow_array::BooleanArray::from(keep);
            let new_cols = batch
                .columns()
                .iter()
                .map(|col| {
                    arrow::compute::filter(col.as_ref(), &keep_mask).map_err(|e| {
                        datafusion::error::DataFusionError::ArrowError(Box::new(e), None)
                    })
                })
                .collect::<DFResult<Vec<_>>>()?;

            if new_cols.first().is_some_and(|c| !c.is_empty()) {
                let filtered = RecordBatch::try_new(Arc::clone(schema), new_cols).map_err(|e| {
                    datafusion::error::DataFusionError::ArrowError(Box::new(e), None)
                })?;
                delta_batches.push(filtered);
            }
        }
        Ok(delta_batches)
    }
}
395
/// Per-rule semi-naive evaluation state: accumulated facts, the most recent
/// delta, and the dedup/aggregation machinery used to detect convergence.
pub struct FixpointState {
    /// Rule name, used in errors and log messages.
    rule_name: String,
    /// All facts derived so far.
    facts: Vec<RecordBatch>,
    /// Facts newly derived by the latest merge; empty once converged.
    delta: Vec<RecordBatch>,
    /// Output schema; may be reconciled against actual plan output.
    schema: SchemaRef,
    /// Indices of the grouping/key columns within `schema`.
    key_column_indices: Vec<usize>,
    /// Key column names, kept so indices can be re-resolved after a schema
    /// reconciliation.
    key_column_names: Vec<String>,
    /// All column indices (0..num_cols), used for whole-row keys.
    all_column_indices: Vec<usize>,
    /// Estimated bytes currently held in `facts`.
    facts_bytes: usize,
    /// Memory budget; exceeding it aborts the fixpoint with an error.
    max_derived_bytes: usize,
    /// Monotonic aggregate state for FOLD rules, when present.
    monotonic_agg: Option<MonotonicAggState>,
    /// Row-format dedup state; `None` when unsupported for `schema`.
    row_dedup: Option<RowDedupState>,
    /// When set, out-of-range probability inputs error instead of clamping.
    strict_probability_domain: bool,
}
426
impl FixpointState {
    /// Builds an empty per-rule state for `schema` with the given key
    /// columns and memory budget.
    pub fn new(
        rule_name: String,
        schema: SchemaRef,
        key_column_indices: Vec<usize>,
        max_derived_bytes: usize,
        monotonic_agg: Option<MonotonicAggState>,
        strict_probability_domain: bool,
    ) -> Self {
        let num_cols = schema.fields().len();
        let row_dedup = RowDedupState::try_new(&schema);
        // Remember key column names so indices can be re-resolved if the
        // physical plan later yields a reordered schema.
        let key_column_names: Vec<String> = key_column_indices
            .iter()
            .filter_map(|&i| schema.fields().get(i).map(|f| f.name().clone()))
            .collect();
        Self {
            rule_name,
            facts: Vec::new(),
            delta: Vec::new(),
            schema,
            key_column_indices,
            key_column_names,
            all_column_indices: (0..num_cols).collect(),
            facts_bytes: 0,
            max_derived_bytes,
            monotonic_agg,
            row_dedup,
            strict_probability_domain,
        }
    }

    /// Adopts `actual_schema` when the physical plan's output fields differ
    /// from the planned schema, rebuilding the dedup state and re-resolving
    /// key column indices by name.
    fn reconcile_schema(&mut self, actual_schema: &SchemaRef) {
        if self.schema.fields() != actual_schema.fields() {
            tracing::debug!(
                rule = %self.rule_name,
                "Reconciling fixpoint schema from physical plan output",
            );
            self.schema = Arc::clone(actual_schema);
            self.row_dedup = RowDedupState::try_new(&self.schema);
            let new_indices: Vec<usize> = self
                .key_column_names
                .iter()
                .filter_map(|name| actual_schema.index_of(name).ok())
                .collect();
            // Only swap indices if every key column was found; otherwise
            // keep the old positional indices.
            if new_indices.len() == self.key_column_names.len() {
                self.key_column_indices = new_indices;
            }
            let num_cols = actual_schema.fields().len();
            self.all_column_indices = (0..num_cols).collect();
        }
    }

    /// Merges candidate batches into the fact set, keeping only rows not
    /// previously derived. Returns `Ok(true)` when anything changed (new
    /// rows or a moved monotonic aggregate).
    ///
    /// # Errors
    /// Fails when the accumulated fact size would exceed
    /// `max_derived_bytes`, or when dedup/aggregation fails.
    pub async fn merge_delta(
        &mut self,
        candidates: Vec<RecordBatch>,
        task_ctx: Option<Arc<TaskContext>>,
    ) -> DFResult<bool> {
        if candidates.is_empty() || candidates.iter().all(|b| b.num_rows() == 0) {
            self.delta.clear();
            return Ok(false);
        }

        if let Some(first) = candidates.iter().find(|b| b.num_rows() > 0) {
            self.reconcile_schema(&first.schema());
        }

        // Round floats so tiny FP drift does not defeat row dedup.
        let candidates = round_float_columns(&candidates);

        let delta = self.compute_delta(&candidates, task_ctx.as_ref()).await?;

        if delta.is_empty() || delta.iter().all(|b| b.num_rows() == 0) {
            self.delta.clear();
            // Snapshot even with no new rows so the aggregate stability
            // check compares against the current iteration.
            if let Some(ref mut agg) = self.monotonic_agg {
                agg.snapshot();
            }
            return Ok(false);
        }

        let delta_bytes: usize = delta.iter().map(batch_byte_size).sum();
        if self.facts_bytes + delta_bytes > self.max_derived_bytes {
            return Err(datafusion::error::DataFusionError::Execution(
                LocyRuntimeError::MemoryLimitExceeded {
                    rule: self.rule_name.clone(),
                    bytes: self.facts_bytes + delta_bytes,
                    limit: self.max_derived_bytes,
                }
                .to_string(),
            ));
        }

        // Snapshot BEFORE updating so is_stable() reflects this delta.
        if let Some(ref mut agg) = self.monotonic_agg {
            agg.snapshot();
            agg.update(
                &self.key_column_indices,
                &delta,
                self.strict_probability_domain,
            )?;
        }

        self.facts_bytes += delta_bytes;
        self.facts.extend(delta.iter().cloned());
        self.delta = delta;

        Ok(true)
    }

    /// Dispatches dedup: once the fact set is large (and a task context is
    /// available), use an arrow LEFT ANTI join; otherwise use the row-format
    /// path, or the legacy scalar-key path when row encoding is unsupported.
    async fn compute_delta(
        &mut self,
        candidates: &[RecordBatch],
        task_ctx: Option<&Arc<TaskContext>>,
    ) -> DFResult<Vec<RecordBatch>> {
        let total_existing: usize = self.facts.iter().map(|b| b.num_rows()).sum();
        if total_existing >= DEDUP_ANTI_JOIN_THRESHOLD
            && let Some(ctx) = task_ctx
        {
            return arrow_left_anti_dedup(candidates.to_vec(), &self.facts, &self.schema, ctx)
                .await;
        }
        if let Some(ref mut rd) = self.row_dedup {
            rd.compute_delta(candidates, &self.schema)
        } else {
            self.compute_delta_legacy(candidates)
        }
    }

    /// Fallback dedup using whole-row `ScalarKey` extraction when the arrow
    /// row format is unavailable for this schema.
    fn compute_delta_legacy(&self, candidates: &[RecordBatch]) -> DFResult<Vec<RecordBatch>> {
        let mut existing: HashSet<Vec<ScalarKey>> = HashSet::new();
        for batch in &self.facts {
            for row_idx in 0..batch.num_rows() {
                let key = extract_scalar_key(batch, &self.all_column_indices, row_idx);
                existing.insert(key);
            }
        }

        let mut delta_batches = Vec::new();
        for batch in candidates {
            if batch.num_rows() == 0 {
                continue;
            }
            // First pass: drop rows already present in the fact set.
            let mut keep = Vec::with_capacity(batch.num_rows());
            for row_idx in 0..batch.num_rows() {
                let key = extract_scalar_key(batch, &self.all_column_indices, row_idx);
                keep.push(!existing.contains(&key));
            }

            // Second pass: drop duplicates within the candidate batches
            // themselves (insert fails for repeats).
            for (row_idx, kept) in keep.iter_mut().enumerate() {
                if *kept {
                    let key = extract_scalar_key(batch, &self.all_column_indices, row_idx);
                    if !existing.insert(key) {
                        *kept = false;
                    }
                }
            }

            let keep_mask = arrow_array::BooleanArray::from(keep);
            let new_rows = batch
                .columns()
                .iter()
                .map(|col| {
                    arrow::compute::filter(col.as_ref(), &keep_mask).map_err(|e| {
                        datafusion::error::DataFusionError::ArrowError(Box::new(e), None)
                    })
                })
                .collect::<DFResult<Vec<_>>>()?;

            if new_rows.first().is_some_and(|c| !c.is_empty()) {
                let filtered =
                    RecordBatch::try_new(Arc::clone(&self.schema), new_rows).map_err(|e| {
                        datafusion::error::DataFusionError::ArrowError(Box::new(e), None)
                    })?;
                delta_batches.push(filtered);
            }
        }

        Ok(delta_batches)
    }

    /// True when the last merge produced no new rows AND any monotonic
    /// aggregate has stabilized.
    pub fn is_converged(&self) -> bool {
        let delta_empty = self.delta.is_empty() || self.delta.iter().all(|b| b.num_rows() == 0);
        let agg_stable = self.monotonic_agg.as_ref().is_none_or(|a| a.is_stable());
        delta_empty && agg_stable
    }

    /// All facts derived so far.
    pub fn all_facts(&self) -> &[RecordBatch] {
        &self.facts
    }

    /// Facts produced by the most recent merge.
    pub fn all_delta(&self) -> &[RecordBatch] {
        &self.delta
    }

    /// Consumes the state, returning the accumulated facts.
    pub fn into_facts(self) -> Vec<RecordBatch> {
        self.facts
    }

    /// BEST BY merge: keeps exactly one row per key group — the one that
    /// sorts first under `sort_criteria`. Rebuilds the whole fact set from
    /// the union of existing facts and `candidates`, and reports whether the
    /// per-key winning criteria values changed.
    pub fn merge_best_by(
        &mut self,
        candidates: Vec<RecordBatch>,
        sort_criteria: &[SortCriterion],
    ) -> DFResult<bool> {
        if candidates.is_empty() || candidates.iter().all(|b| b.num_rows() == 0) {
            self.delta.clear();
            return Ok(false);
        }

        if let Some(first) = candidates.iter().find(|b| b.num_rows() > 0) {
            self.reconcile_schema(&first.schema());
        }

        let candidates = round_float_columns(&candidates);

        // Snapshot the old winners before merging, for change detection.
        let old_best: HashMap<Vec<ScalarKey>, Vec<ScalarKey>> =
            self.build_key_criteria_map(sort_criteria);

        let mut all_batches = self.facts.clone();
        all_batches.extend(candidates);
        let all_batches: Vec<_> = all_batches
            .into_iter()
            .filter(|b| b.num_rows() > 0)
            .collect();
        if all_batches.is_empty() {
            self.delta.clear();
            return Ok(false);
        }

        let combined = arrow::compute::concat_batches(&self.schema, &all_batches)
            .map_err(|e| datafusion::error::DataFusionError::ArrowError(Box::new(e), None))?;

        if combined.num_rows() == 0 {
            self.delta.clear();
            return Ok(false);
        }

        // Sort by key columns first so each group is contiguous, then by the
        // BEST BY criteria so the winner is the first row of its group.
        let mut sort_columns = Vec::new();
        for &ki in &self.key_column_indices {
            if ki >= combined.num_columns() {
                continue;
            }
            sort_columns.push(arrow::compute::SortColumn {
                values: Arc::clone(combined.column(ki)),
                options: Some(arrow::compute::SortOptions {
                    descending: false,
                    nulls_first: false,
                }),
            });
        }
        for criterion in sort_criteria {
            if criterion.col_index >= combined.num_columns() {
                continue;
            }
            sort_columns.push(arrow::compute::SortColumn {
                values: Arc::clone(combined.column(criterion.col_index)),
                options: Some(arrow::compute::SortOptions {
                    descending: !criterion.ascending,
                    nulls_first: criterion.nulls_first,
                }),
            });
        }

        let sorted_indices =
            arrow::compute::lexsort_to_indices(&sort_columns, None).map_err(arrow_err)?;
        let sorted_columns: Vec<_> = combined
            .columns()
            .iter()
            .map(|col| arrow::compute::take(col.as_ref(), &sorted_indices, None))
            .collect::<Result<Vec<_>, _>>()
            .map_err(|e| datafusion::error::DataFusionError::ArrowError(Box::new(e), None))?;
        let sorted = RecordBatch::try_new(Arc::clone(&self.schema), sorted_columns)
            .map_err(|e| datafusion::error::DataFusionError::ArrowError(Box::new(e), None))?;

        // Keep the first (= best) row of each key group.
        let mut keep_indices: Vec<u32> = Vec::new();
        let mut prev_key: Option<Vec<ScalarKey>> = None;
        for row_idx in 0..sorted.num_rows() {
            let key = extract_scalar_key(&sorted, &self.key_column_indices, row_idx);
            let is_new_group = match &prev_key {
                None => true,
                Some(prev) => *prev != key,
            };
            if is_new_group {
                keep_indices.push(row_idx as u32);
                prev_key = Some(key);
            }
        }

        let keep_array = arrow_array::UInt32Array::from(keep_indices);
        let output_columns: Vec<_> = sorted
            .columns()
            .iter()
            .map(|col| arrow::compute::take(col.as_ref(), &keep_array, None))
            .collect::<Result<Vec<_>, _>>()
            .map_err(|e| datafusion::error::DataFusionError::ArrowError(Box::new(e), None))?;
        let pruned = RecordBatch::try_new(Arc::clone(&self.schema), output_columns)
            .map_err(|e| datafusion::error::DataFusionError::ArrowError(Box::new(e), None))?;

        // Changed iff the set of per-key winning criteria values differs.
        let new_best: HashMap<Vec<ScalarKey>, Vec<ScalarKey>> = {
            let mut map = HashMap::new();
            for row_idx in 0..pruned.num_rows() {
                let key = extract_scalar_key(&pruned, &self.key_column_indices, row_idx);
                let criteria: Vec<ScalarKey> = sort_criteria
                    .iter()
                    .flat_map(|c| extract_scalar_key(&pruned, &[c.col_index], row_idx))
                    .collect();
                map.insert(key, criteria);
            }
            map
        };
        let changed = old_best != new_best;

        tracing::debug!(
            rule = %self.rule_name,
            old_keys = old_best.len(),
            new_keys = new_best.len(),
            changed = changed,
            "BEST BY merge"
        );

        // The pruned batch fully replaces the fact set.
        self.facts_bytes = batch_byte_size(&pruned);
        self.facts = vec![pruned];
        if changed {
            self.delta = self.facts.clone();
        } else {
            self.delta.clear();
        }

        // Rebuild row dedup from scratch since rows may have been replaced.
        self.row_dedup = RowDedupState::try_new(&self.schema);
        if let Some(ref mut rd) = self.row_dedup {
            rd.ingest_existing(&self.facts, &self.schema);
        }

        Ok(changed)
    }

    /// Maps each key group in the current facts to its criteria values,
    /// used by `merge_best_by` for change detection.
    fn build_key_criteria_map(
        &self,
        sort_criteria: &[SortCriterion],
    ) -> HashMap<Vec<ScalarKey>, Vec<ScalarKey>> {
        let mut map = HashMap::new();
        for batch in &self.facts {
            for row_idx in 0..batch.num_rows() {
                let key = extract_scalar_key(batch, &self.key_column_indices, row_idx);
                let criteria: Vec<ScalarKey> = sort_criteria
                    .iter()
                    .flat_map(|c| extract_scalar_key(batch, &[c.col_index], row_idx))
                    .collect();
                map.insert(key, criteria);
            }
        }
        map
    }
}
839
840fn batch_byte_size(batch: &RecordBatch) -> usize {
842 batch
843 .columns()
844 .iter()
845 .map(|col| col.get_buffer_memory_size())
846 .sum()
847}
848
849fn round_float_columns(batches: &[RecordBatch]) -> Vec<RecordBatch> {
855 batches
856 .iter()
857 .map(|batch| {
858 let schema = batch.schema();
859 let has_float = schema
860 .fields()
861 .iter()
862 .any(|f| *f.data_type() == arrow_schema::DataType::Float64);
863 if !has_float {
864 return batch.clone();
865 }
866
867 let columns: Vec<arrow_array::ArrayRef> = batch
868 .columns()
869 .iter()
870 .enumerate()
871 .map(|(i, col)| {
872 if *schema.field(i).data_type() == arrow_schema::DataType::Float64 {
873 let arr = col
874 .as_any()
875 .downcast_ref::<arrow_array::Float64Array>()
876 .unwrap();
877 let rounded: arrow_array::Float64Array = arr
878 .iter()
879 .map(|v| v.map(|f| (f * 1e12).round() / 1e12))
880 .collect();
881 Arc::new(rounded) as arrow_array::ArrayRef
882 } else {
883 Arc::clone(col)
884 }
885 })
886 .collect();
887
888 RecordBatch::try_new(schema, columns).unwrap_or_else(|_| batch.clone())
889 })
890 .collect()
891}
892
/// Existing-fact row count at which dedup switches from the in-memory
/// row-format path to a DataFusion LEFT ANTI hash join
/// (see `FixpointState::compute_delta`).
const DEDUP_ANTI_JOIN_THRESHOLD: usize = 300;
903
/// Removes from `candidates` every row already present in `existing` by
/// running a LEFT ANTI hash join over all columns.
///
/// Uses `NullEqualsNull` so rows containing NULLs still dedup against
/// identical rows. Returns `candidates` unchanged when `existing` is empty,
/// and an empty result for zero-column schemas (no join keys possible).
async fn arrow_left_anti_dedup(
    candidates: Vec<RecordBatch>,
    existing: &[RecordBatch],
    schema: &SchemaRef,
    task_ctx: &Arc<TaskContext>,
) -> DFResult<Vec<RecordBatch>> {
    if existing.is_empty() || existing.iter().all(|b| b.num_rows() == 0) {
        return Ok(candidates);
    }

    let left: Arc<dyn ExecutionPlan> = Arc::new(InMemoryExec::new(candidates, Arc::clone(schema)));
    let right: Arc<dyn ExecutionPlan> =
        Arc::new(InMemoryExec::new(existing.to_vec(), Arc::clone(schema)));

    // Join on every column: a full-row match means the row is a duplicate.
    let on: Vec<(
        Arc<dyn datafusion::physical_plan::PhysicalExpr>,
        Arc<dyn datafusion::physical_plan::PhysicalExpr>,
    )> = schema
        .fields()
        .iter()
        .enumerate()
        .map(|(i, field)| {
            let l: Arc<dyn datafusion::physical_plan::PhysicalExpr> = Arc::new(
                datafusion::physical_plan::expressions::Column::new(field.name(), i),
            );
            let r: Arc<dyn datafusion::physical_plan::PhysicalExpr> = Arc::new(
                datafusion::physical_plan::expressions::Column::new(field.name(), i),
            );
            (l, r)
        })
        .collect();

    if on.is_empty() {
        return Ok(vec![]);
    }

    let join = HashJoinExec::try_new(
        left,
        right,
        on,
        None,
        &JoinType::LeftAnti,
        None,
        PartitionMode::CollectLeft,
        datafusion::common::NullEquality::NullEqualsNull,
    )?;

    let join_arc: Arc<dyn ExecutionPlan> = Arc::new(join);
    collect_all_partitions(&join_arc, task_ctx.clone()).await
}
958
/// Describes how a rule clause references another rule's (or its own)
/// derived table.
#[derive(Debug, Clone)]
pub struct IsRefBinding {
    /// Registry index of the referenced derived scan.
    pub derived_scan_index: usize,
    /// Name of the referenced rule.
    pub rule_name: String,
    /// True when the reference targets the enclosing rule itself.
    pub is_self_ref: bool,
    /// True for negated references, handled via anti-join or probability
    /// complement.
    pub negated: bool,
    /// (body column, derived column) pairs matched during the anti-join of
    /// a negated reference.
    pub anti_join_cols: Vec<(String, String)>,
    /// Whether the referenced table carries a probability column.
    pub target_has_prob: bool,
    /// Name of the referenced table's probability column, if any.
    pub target_prob_col: Option<String>,
    /// (body column, derived column) pairs used to attribute provenance of
    /// derived rows back to their supporting facts.
    pub provenance_join_cols: Vec<(String, String)>,
}
990
/// One clause of a fixpoint rule: the body plan plus its references and
/// metadata.
#[derive(Debug)]
pub struct FixpointClausePlan {
    /// Logical plan executed each iteration to produce candidate rows.
    pub body_logical: LogicalPlan,
    /// References from this clause body to derived tables.
    pub is_ref_bindings: Vec<IsRefBinding>,
    /// Optional clause-level priority value.
    pub priority: Option<i64>,
    /// Column names captured as ALONG values in provenance annotations.
    pub along_bindings: Vec<String>,
}
1003
/// Full plan for one fixpoint rule, covering its clauses and the
/// post-fixpoint features (FOLD, HAVING, BEST BY, PRIORITY) it uses.
#[derive(Debug)]
pub struct FixpointRulePlan {
    /// Rule name.
    pub name: String,
    /// Rule clauses, evaluated each iteration.
    pub clauses: Vec<FixpointClausePlan>,
    /// Schema of the rule's yielded rows.
    pub yield_schema: SchemaRef,
    /// Indices of the key columns within `yield_schema`.
    pub key_column_indices: Vec<usize>,
    /// Optional rule-level priority value.
    pub priority: Option<i64>,
    /// Whether the rule has a FOLD aggregation stage.
    pub has_fold: bool,
    /// Fold bindings applied after (or during) the fixpoint.
    pub fold_bindings: Vec<FoldBinding>,
    /// HAVING predicates applied after aggregation.
    pub having: Vec<Expr>,
    /// Whether the rule uses BEST BY semantics.
    pub has_best_by: bool,
    /// Sort criteria defining "best" for BEST BY rules.
    pub best_by_criteria: Vec<SortCriterion>,
    /// Whether the rule has a PRIORITY stage.
    pub has_priority: bool,
    /// Whether evaluation must be deterministic.
    pub deterministic: bool,
    /// Name of the probability column, when the rule is probabilistic.
    pub prob_column_name: Option<String>,
}
1036
/// Drives semi-naive fixpoint evaluation of `rules` to convergence.
///
/// Each iteration re-executes every clause body against the current derived
/// tables, applies negated references (anti-join or probability complement),
/// and merges candidates into the per-rule `FixpointState` (delta merge, or
/// BEST BY merge for BEST BY rules). The loop ends when no rule changed and
/// every state reports convergence; exceeding `max_iterations` or `timeout`
/// yields a `NonConvergence` error. After convergence, per-rule facts flow
/// through optional shared-lineage detection / exact WMC and the
/// post-fixpoint chain (fold/having/best-by/priority) before being returned.
///
/// # Errors
/// Non-convergence, memory-limit violation, strict probability-domain
/// violations, and any subplan execution failure.
#[expect(clippy::too_many_arguments, reason = "Fixpoint loop needs all context")]
async fn run_fixpoint_loop(
    rules: Vec<FixpointRulePlan>,
    max_iterations: usize,
    timeout: Duration,
    graph_ctx: Arc<GraphExecutionContext>,
    session_ctx: Arc<RwLock<datafusion::prelude::SessionContext>>,
    storage: Arc<StorageManager>,
    schema_info: Arc<UniSchema>,
    params: HashMap<String, Value>,
    registry: Arc<DerivedScanRegistry>,
    output_schema: SchemaRef,
    max_derived_bytes: usize,
    derivation_tracker: Option<Arc<ProvenanceStore>>,
    iteration_counts: Arc<StdRwLock<HashMap<String, usize>>>,
    strict_probability_domain: bool,
    probability_epsilon: f64,
    exact_probability: bool,
    max_bdd_variables: usize,
    warnings_slot: Arc<StdRwLock<Vec<RuntimeWarning>>>,
    approximate_slot: Arc<StdRwLock<HashMap<String, Vec<String>>>>,
    top_k_proofs: usize,
) -> DFResult<Vec<RecordBatch>> {
    let start = Instant::now();
    let task_ctx = session_ctx.read().task_ctx();

    // One evaluation state per rule; rules with FOLD bindings also get a
    // monotonic-aggregate state for convergence tracking.
    let mut states: Vec<FixpointState> = rules
        .iter()
        .map(|rule| {
            let monotonic_agg = if !rule.fold_bindings.is_empty() {
                let bindings: Vec<MonotonicFoldBinding> = rule
                    .fold_bindings
                    .iter()
                    .map(|fb| MonotonicFoldBinding {
                        fold_name: fb.output_name.clone(),
                        kind: fb.kind.clone(),
                        input_col_index: fb.input_col_index,
                        input_col_name: fb.input_col_name.clone(),
                    })
                    .collect();
                Some(MonotonicAggState::new(bindings))
            } else {
                None
            };
            FixpointState::new(
                rule.name.clone(),
                Arc::clone(&rule.yield_schema),
                rule.key_column_indices.clone(),
                max_derived_bytes,
                monotonic_agg,
                strict_probability_domain,
            )
        })
        .collect();

    let mut converged = false;
    let mut total_iters = 0usize;
    for iteration in 0..max_iterations {
        total_iters = iteration + 1;
        tracing::debug!("fixpoint iteration {}", iteration);
        let mut any_changed = false;

        for rule_idx in 0..rules.len() {
            let rule = &rules[rule_idx];

            // Publish the current facts of all rules so this rule's clause
            // bodies scan up-to-date derived tables.
            update_derived_scan_handles(&registry, &states, rule_idx, &rules);

            let mut all_candidates = Vec::new();
            let mut clause_candidates: Vec<Vec<RecordBatch>> = Vec::new();
            for clause in &rule.clauses {
                let mut batches = execute_subplan(
                    &clause.body_logical,
                    &params,
                    &HashMap::new(),
                    &graph_ctx,
                    &session_ctx,
                    &storage,
                    &schema_info,
                )
                .await?;
                // Apply negated references: probability complement when the
                // target is probabilistic, plain anti-join otherwise.
                for binding in &clause.is_ref_bindings {
                    if binding.negated
                        && !binding.anti_join_cols.is_empty()
                        && let Some(entry) = registry.get(binding.derived_scan_index)
                    {
                        let neg_facts = entry.data.read().clone();
                        if !neg_facts.is_empty() {
                            if binding.target_has_prob && rule.prob_column_name.is_some() {
                                let complement_col =
                                    format!("__prob_complement_{}", binding.rule_name);
                                if let Some(prob_col) = &binding.target_prob_col {
                                    batches = apply_prob_complement_composite(
                                        batches,
                                        &neg_facts,
                                        &binding.anti_join_cols,
                                        prob_col,
                                        &complement_col,
                                    )?;
                                } else {
                                    batches = apply_anti_join_composite(
                                        batches,
                                        &neg_facts,
                                        &binding.anti_join_cols,
                                    )?;
                                }
                            } else {
                                batches = apply_anti_join_composite(
                                    batches,
                                    &neg_facts,
                                    &binding.anti_join_cols,
                                )?;
                            }
                        }
                    }
                }
                // Fold any accumulated complement factors into the rule's
                // probability column and drop the temp columns.
                let complement_cols: Vec<String> = if !batches.is_empty() {
                    batches[0]
                        .schema()
                        .fields()
                        .iter()
                        .filter(|f| f.name().starts_with("__prob_complement_"))
                        .map(|f| f.name().clone())
                        .collect()
                } else {
                    vec![]
                };
                if !complement_cols.is_empty() {
                    batches = multiply_prob_factors(
                        batches,
                        rule.prob_column_name.as_deref(),
                        &complement_cols,
                    )?;
                }

                clause_candidates.push(batches.clone());
                all_candidates.extend(batches);
            }

            let changed = if rule.has_best_by && !rule.best_by_criteria.is_empty() {
                states[rule_idx].merge_best_by(all_candidates, &rule.best_by_criteria)?
            } else {
                states[rule_idx]
                    .merge_delta(all_candidates, Some(Arc::clone(&task_ctx)))
                    .await?
            };
            if changed {
                any_changed = true;
                // Record provenance only for rows that actually changed.
                if let Some(ref tracker) = derivation_tracker {
                    record_provenance(
                        tracker,
                        rule,
                        &states[rule_idx],
                        &clause_candidates,
                        iteration,
                        &registry,
                        top_k_proofs,
                    );
                }
            }
        }

        if !any_changed && states.iter().all(|s| s.is_converged()) {
            tracing::debug!("fixpoint converged after {} iterations", iteration + 1);
            converged = true;
            break;
        }

        // Timeout is checked once per iteration, after the work is done.
        if start.elapsed() > timeout {
            return Err(datafusion::error::DataFusionError::Execution(
                LocyRuntimeError::NonConvergence {
                    iterations: iteration + 1,
                }
                .to_string(),
            ));
        }
    }

    // Best-effort: a poisoned lock silently skips recording counts.
    if let Ok(mut counts) = iteration_counts.write() {
        for rule in &rules {
            counts.insert(rule.name.clone(), total_iters);
        }
    }

    if !converged {
        return Err(datafusion::error::DataFusionError::Execution(
            LocyRuntimeError::NonConvergence {
                iterations: max_iterations,
            }
            .to_string(),
        ));
    }

    let task_ctx = session_ctx.read().task_ctx();
    let mut all_output = Vec::new();

    for (rule_idx, state) in states.into_iter().enumerate() {
        let rule = &rules[rule_idx];
        let mut facts = state.into_facts();
        if facts.is_empty() {
            continue;
        }

        let shared_info = if let Some(ref tracker) = derivation_tracker {
            detect_shared_lineage(rule, &facts, tracker, &warnings_slot)
        } else {
            None
        };

        // Exact weighted model counting only applies when facts share
        // lineage AND provenance was tracked.
        if exact_probability
            && let Some(ref info) = shared_info
            && let Some(ref tracker) = derivation_tracker
        {
            facts = apply_exact_wmc(
                facts,
                rule,
                info,
                tracker,
                max_bdd_variables,
                &warnings_slot,
                &approximate_slot,
            )?;
        }

        let processed = apply_post_fixpoint_chain(
            facts,
            rule,
            &task_ctx,
            strict_probability_domain,
            probability_epsilon,
        )
        .await?;
        all_output.extend(processed);
    }

    // Callers expect at least one (possibly empty) batch with the output
    // schema attached.
    if all_output.is_empty() {
        all_output.push(RecordBatch::new_empty(output_schema));
    }

    Ok(all_output)
}
1306
/// Records a provenance annotation for every row of the rule's latest
/// delta, keyed by a hash of the row's full scalar key.
///
/// When `top_k_proofs > 0`, a per-proof probability is computed from base
/// fact probabilities and only the top-k proofs per fact are retained.
fn record_provenance(
    tracker: &Arc<ProvenanceStore>,
    rule: &FixpointRulePlan,
    state: &FixpointState,
    clause_candidates: &[Vec<RecordBatch>],
    iteration: usize,
    registry: &Arc<DerivedScanRegistry>,
    top_k_proofs: usize,
) {
    let all_indices: Vec<usize> = (0..rule.yield_schema.fields().len()).collect();

    // Base fact probabilities are only needed for top-k proof scoring.
    let base_probs = if top_k_proofs > 0 {
        tracker.base_fact_probs()
    } else {
        HashMap::new()
    };

    for delta_batch in state.all_delta() {
        for row_idx in 0..delta_batch.num_rows() {
            // Identify the fact by the debug form of its whole-row key.
            let row_hash = format!(
                "{:?}",
                extract_scalar_key(delta_batch, &all_indices, row_idx)
            )
            .into_bytes();
            let fact_row = batch_row_to_value_map(delta_batch, row_idx);
            let clause_index =
                find_clause_for_row(delta_batch, row_idx, &all_indices, clause_candidates);

            let support = collect_is_ref_inputs(rule, clause_index, delta_batch, row_idx, registry);

            let proof_probability = if top_k_proofs > 0 {
                compute_proof_probability(&support, &base_probs)
            } else {
                None
            };

            let entry = ProvenanceAnnotation {
                rule_name: rule.name.clone(),
                clause_index,
                support,
                // Capture the clause's ALONG binding values from the row.
                along_values: {
                    let along_names: Vec<String> = rule
                        .clauses
                        .get(clause_index)
                        .map(|c| c.along_bindings.clone())
                        .unwrap_or_default();
                    along_names
                        .iter()
                        .filter_map(|name| fact_row.get(name).map(|v| (name.clone(), v.clone())))
                        .collect()
                },
                iteration,
                fact_row,
                proof_probability,
            };
            if top_k_proofs > 0 {
                tracker.record_top_k(row_hash, entry, top_k_proofs);
            } else {
                tracker.record(row_hash, entry);
            }
        }
    }
}
1379
/// Collects, for one derived row, the supporting facts it joined against in
/// each (non-negated) derived-table reference of its clause.
///
/// Matching is a linear scan over the referenced table's rows per binding —
/// O(delta rows × source rows); acceptable when provenance tracking is an
/// opt-in explain feature.
fn collect_is_ref_inputs(
    rule: &FixpointRulePlan,
    clause_index: usize,
    delta_batch: &RecordBatch,
    row_idx: usize,
    registry: &Arc<DerivedScanRegistry>,
) -> Vec<ProofTerm> {
    let clause = match rule.clauses.get(clause_index) {
        Some(c) => c,
        None => return vec![],
    };

    let mut inputs = Vec::new();
    let delta_schema = delta_batch.schema();

    for binding in &clause.is_ref_bindings {
        // Negated references exclude facts; they do not support the row.
        if binding.negated {
            continue;
        }
        if binding.provenance_join_cols.is_empty() {
            continue;
        }

        // Resolve the join-key values of this row on the body side.
        let body_values: Vec<(String, ScalarKey)> = binding
            .provenance_join_cols
            .iter()
            .filter_map(|(body_col, _derived_col)| {
                let col_idx = delta_schema
                    .fields()
                    .iter()
                    .position(|f| f.name() == body_col)?;
                let key = extract_scalar_key(delta_batch, &[col_idx], row_idx);
                Some((body_col.clone(), key.into_iter().next()?))
            })
            .collect();

        // If any join column could not be resolved, skip the binding
        // rather than record partial matches.
        if body_values.len() != binding.provenance_join_cols.len() {
            continue;
        }

        let entry = match registry.get(binding.derived_scan_index) {
            Some(e) => e,
            None => continue,
        };
        let source_batches = entry.data.read();
        let source_schema = &entry.schema;

        for src_batch in source_batches.iter() {
            let all_src_indices: Vec<usize> = (0..src_batch.num_columns()).collect();
            for src_row in 0..src_batch.num_rows() {
                // A source row supports this fact when every provenance
                // join column matches the body-side value.
                let matches = binding.provenance_join_cols.iter().enumerate().all(
                    |(i, (_body_col, derived_col))| {
                        let src_col_idx = source_schema
                            .fields()
                            .iter()
                            .position(|f| f.name() == derived_col);
                        match src_col_idx {
                            Some(idx) => {
                                let src_key = extract_scalar_key(src_batch, &[idx], src_row);
                                src_key.first() == Some(&body_values[i].1)
                            }
                            None => false,
                        }
                    },
                );
                if matches {
                    let fact_hash = format!(
                        "{:?}",
                        extract_scalar_key(src_batch, &all_src_indices, src_row)
                    )
                    .into_bytes();
                    inputs.push(ProofTerm {
                        source_rule: binding.rule_name.clone(),
                        base_fact_id: fact_hash,
                    });
                }
            }
        }
    }

    inputs
}
1470
/// One row of a KEY group that shares lineage with a sibling row.
#[expect(
    dead_code,
    reason = "Fields accessed via SharedLineageInfo in detect_shared_lineage"
)]
pub(crate) struct SharedGroupRow {
    // Stable byte identifier of the row's full tuple (see `fact_hash_key`).
    pub fact_hash: Vec<u8>,
    // Transitive base facts this row depends on (see `compute_lineage`).
    pub lineage: HashSet<Vec<u8>>,
}

/// Result of `detect_shared_lineage`: for each KEY whose rows were found to
/// share base facts, the member rows together with their resolved lineage.
pub(crate) struct SharedLineageInfo {
    pub shared_groups: HashMap<Vec<ScalarKey>, Vec<SharedGroupRow>>,
}
1503
1504fn fact_hash_key(batch: &RecordBatch, all_indices: &[usize], row_idx: usize) -> Vec<u8> {
1506 format!("{:?}", extract_scalar_key(batch, all_indices, row_idx)).into_bytes()
1507}
1508
/// Detects KEY groups whose member rows share base facts in their lineage,
/// which violates the independence assumption of probabilistic folds
/// (MNOR/MPROD).
///
/// Returns `Some(SharedLineageInfo)` describing the affected groups so the
/// caller can run exact weighted model counting on them, or `None` when the
/// rule has no probabilistic fold or no sharing was found. Also pushes
/// `CrossGroupCorrelationNotExact` and `SharedProbabilisticDependency`
/// warnings into `warnings_slot`, deduplicated per rule.
fn detect_shared_lineage(
    rule: &FixpointRulePlan,
    pre_fold_facts: &[RecordBatch],
    tracker: &Arc<ProvenanceStore>,
    warnings_slot: &Arc<StdRwLock<Vec<RuntimeWarning>>>,
) -> Option<SharedLineageInfo> {
    use crate::query::df_graph::locy_fold::FoldAggKind;
    use uni_locy::{RuntimeWarning, RuntimeWarningCode};

    // Only MNOR/MPROD folds assume independence; other rules are unaffected.
    let has_prob_fold = rule
        .fold_bindings
        .iter()
        .any(|fb| matches!(fb.kind, FoldAggKind::Nor | FoldAggKind::Prod));
    if !has_prob_fold {
        return None;
    }

    let key_indices = &rule.key_column_indices;
    let all_indices: Vec<usize> = (0..rule.yield_schema.fields().len()).collect();

    // Bucket every pre-fold row's fact hash by its KEY column values.
    let mut groups: HashMap<Vec<ScalarKey>, Vec<Vec<u8>>> = HashMap::new();
    for batch in pre_fold_facts {
        for row_idx in 0..batch.num_rows() {
            let key = extract_scalar_key(batch, key_indices, row_idx);
            let fact_hash = fact_hash_key(batch, &all_indices, row_idx);
            groups.entry(key).or_default().push(fact_hash);
        }
    }

    let mut shared_groups: HashMap<Vec<ScalarKey>, Vec<SharedGroupRow>> = HashMap::new();
    let mut any_shared = false;

    for (key, fact_hashes) in &groups {
        // A single-row group cannot share lineage with itself.
        if fact_hashes.len() < 2 {
            continue;
        }

        // Resolve each row to its transitive base facts, and note whether any
        // row has recorded support at all.
        let mut has_inputs = false;
        let mut per_row_bases: Vec<HashSet<Vec<u8>>> = Vec::new();
        for fh in fact_hashes {
            let bases = compute_lineage(fh, tracker, &mut HashSet::new());
            if let Some(entry) = tracker.lookup(fh)
                && !entry.support.is_empty()
            {
                has_inputs = true;
            }
            per_row_bases.push(bases);
        }

        let shared_found = if has_inputs {
            // Pairwise overlap test on the base-fact sets.
            let mut found = false;
            'outer: for i in 0..per_row_bases.len() {
                for j in (i + 1)..per_row_bases.len() {
                    if !per_row_bases[i].is_disjoint(&per_row_bases[j]) {
                        found = true;
                        break 'outer;
                    }
                }
            }
            found
        } else {
            // No recorded support: conservatively flag the group when any row
            // came from a clause with a positive IS-ref binding (its lineage
            // may simply not have been tracked).
            fact_hashes.iter().any(|fh| {
                tracker.lookup(fh).is_some_and(|entry| {
                    rule.clauses
                        .get(entry.clause_index)
                        .is_some_and(|clause| clause.is_ref_bindings.iter().any(|b| !b.negated))
                })
            })
        };

        if shared_found {
            any_shared = true;
            let rows: Vec<SharedGroupRow> = fact_hashes
                .iter()
                .zip(per_row_bases.into_iter())
                .map(|(fh, bases)| SharedGroupRow {
                    fact_hash: fh.clone(),
                    lineage: bases,
                })
                .collect();
            shared_groups.insert(key.clone(), rows);
        }
    }

    // Separate check: a base fact feeding more than one KEY group means the
    // per-group BDD correction cannot fix cross-group correlation — warn once.
    {
        let mut input_to_groups: HashMap<Vec<u8>, HashSet<Vec<ScalarKey>>> = HashMap::new();
        for (key, fact_hashes) in &groups {
            for fh in fact_hashes {
                if let Some(entry) = tracker.lookup(fh) {
                    for input in &entry.support {
                        input_to_groups
                            .entry(input.base_fact_id.clone())
                            .or_default()
                            .insert(key.clone());
                    }
                }
            }
        }
        let has_cross_group = input_to_groups.values().any(|g| g.len() > 1);
        if has_cross_group && let Ok(mut warnings) = warnings_slot.write() {
            // Deduplicate: one warning of this code per rule.
            let already_warned = warnings.iter().any(|w| {
                w.code == RuntimeWarningCode::CrossGroupCorrelationNotExact
                    && w.rule_name == rule.name
            });
            if !already_warned {
                warnings.push(RuntimeWarning {
                    code: RuntimeWarningCode::CrossGroupCorrelationNotExact,
                    message: format!(
                        "Rule '{}': IS-ref base facts are shared across different KEY \
                        groups. BDD corrects per-group probabilities but cannot account \
                        for cross-group correlations.",
                        rule.name
                    ),
                    rule_name: rule.name.clone(),
                    variable_count: None,
                    key_group: None,
                });
            }
        }
    }

    if any_shared {
        if let Ok(mut warnings) = warnings_slot.write() {
            // Deduplicate: one warning of this code per rule.
            let already_warned = warnings.iter().any(|w| {
                w.code == RuntimeWarningCode::SharedProbabilisticDependency
                    && w.rule_name == rule.name
            });
            if !already_warned {
                warnings.push(RuntimeWarning {
                    code: RuntimeWarningCode::SharedProbabilisticDependency,
                    message: format!(
                        "Rule '{}' aggregates with MNOR/MPROD but some proof paths \
                        share intermediate facts, violating the independence assumption. \
                        Results may overestimate probability.",
                        rule.name
                    ),
                    rule_name: rule.name.clone(),
                    variable_count: None,
                    key_group: None,
                });
            }
        }
        Some(SharedLineageInfo { shared_groups })
    } else {
        None
    }
}
1671
/// Records provenance for every fact produced by a non-recursive rule, then
/// runs shared-lineage detection over all of them.
///
/// `tagged_clause_facts` pairs each clause index with the batches that clause
/// emitted, so — unlike the recursive path — rows never need to be matched
/// back to their producing clause. When `top_k_proofs > 0`, per-proof
/// probabilities are computed and only the top-k proofs per fact are kept.
///
/// Returns `Some(SharedLineageInfo)` when shared lineage was detected (see
/// `detect_shared_lineage`), otherwise `None`.
pub(crate) fn record_and_detect_lineage_nonrecursive(
    rule: &FixpointRulePlan,
    tagged_clause_facts: &[(usize, Vec<RecordBatch>)],
    tracker: &Arc<ProvenanceStore>,
    warnings_slot: &Arc<StdRwLock<Vec<RuntimeWarning>>>,
    registry: &Arc<DerivedScanRegistry>,
    top_k_proofs: usize,
) -> Option<SharedLineageInfo> {
    let all_indices: Vec<usize> = (0..rule.yield_schema.fields().len()).collect();

    // Base-fact probabilities are only needed for top-k proof scoring.
    let base_probs = if top_k_proofs > 0 {
        tracker.base_fact_probs()
    } else {
        HashMap::new()
    };

    for (clause_index, batches) in tagged_clause_facts {
        for batch in batches {
            for row_idx in 0..batch.num_rows() {
                let row_hash = fact_hash_key(batch, &all_indices, row_idx);
                let fact_row = batch_row_to_value_map(batch, row_idx);

                // IS-ref source rows that support this fact.
                let support = collect_is_ref_inputs(rule, *clause_index, batch, row_idx, registry);

                let proof_probability = if top_k_proofs > 0 {
                    compute_proof_probability(&support, &base_probs)
                } else {
                    None
                };

                let entry = ProvenanceAnnotation {
                    rule_name: rule.name.clone(),
                    clause_index: *clause_index,
                    support,
                    // Capture the values of this clause's ALONG bindings from
                    // the fact row, keyed by binding name.
                    along_values: {
                        let along_names: Vec<String> = rule
                            .clauses
                            .get(*clause_index)
                            .map(|c| c.along_bindings.clone())
                            .unwrap_or_default();
                        along_names
                            .iter()
                            .filter_map(|name| {
                                fact_row.get(name).map(|v| (name.clone(), v.clone()))
                            })
                            .collect()
                    },
                    // Non-recursive rules are all "iteration 0".
                    iteration: 0,
                    fact_row,
                    proof_probability,
                };
                if top_k_proofs > 0 {
                    tracker.record_top_k(row_hash, entry, top_k_proofs);
                } else {
                    tracker.record(row_hash, entry);
                }
            }
        }
    }

    // Flatten all clauses' outputs and look for shared lineage across them.
    let all_facts: Vec<RecordBatch> = tagged_clause_facts
        .iter()
        .flat_map(|(_, batches)| batches.iter().cloned())
        .collect();
    detect_shared_lineage(rule, &all_facts, tracker, warnings_slot)
}
1748
/// Replaces the independence-assuming probabilities of shared-lineage KEY
/// groups with exact weighted model counts computed over a BDD.
///
/// For every group listed in `shared_info`, the rows' base-fact sets and
/// base-fact probabilities are fed to `weighted_model_count`. On success the
/// group collapses to a single representative row whose probability column is
/// overwritten with the exact result; if the BDD exceeds `max_bdd_variables`,
/// all of the group's rows are kept unchanged (independence-mode fallback), a
/// `BddLimitExceeded` warning is recorded, and the group is noted in
/// `approximate_slot`. Rows outside shared groups pass through untouched.
///
/// Returns the rewritten batches, or the input unchanged when the rule has no
/// MNOR/MPROD fold.
pub(crate) fn apply_exact_wmc(
    pre_fold_facts: Vec<RecordBatch>,
    rule: &FixpointRulePlan,
    shared_info: &SharedLineageInfo,
    tracker: &Arc<ProvenanceStore>,
    max_bdd_variables: usize,
    warnings_slot: &Arc<StdRwLock<Vec<RuntimeWarning>>>,
    approximate_slot: &Arc<StdRwLock<HashMap<String, Vec<String>>>>,
) -> DFResult<Vec<RecordBatch>> {
    use crate::query::df_graph::locy_bdd::{SemiringOp, weighted_model_count};
    use crate::query::df_graph::locy_fold::FoldAggKind;
    use uni_locy::{RuntimeWarning, RuntimeWarningCode};

    // Locate the probabilistic fold; without one there is nothing to correct.
    let prob_fold = rule
        .fold_bindings
        .iter()
        .find(|fb| matches!(fb.kind, FoldAggKind::Nor | FoldAggKind::Prod));
    let prob_fold = match prob_fold {
        Some(f) => f,
        None => return Ok(pre_fold_facts),
    };
    // MNOR corresponds to disjunction over proofs, MPROD to conjunction.
    let semiring_op = if matches!(prob_fold.kind, FoldAggKind::Nor) {
        SemiringOp::Disjunction
    } else {
        SemiringOp::Conjunction
    };
    let prob_col_idx = prob_fold.input_col_index;
    let prob_col_name = rule.yield_schema.field(prob_col_idx).name().clone();

    let key_indices = &rule.key_column_indices;
    let all_indices: Vec<usize> = (0..rule.yield_schema.fields().len()).collect();

    let shared_keys: HashSet<Vec<ScalarKey>> = shared_info.shared_groups.keys().cloned().collect();

    // Per-group accumulation of the WMC inputs plus the physical locations of
    // the group's rows in the input batches.
    struct GroupAccum {
        // One base-fact set per member row (the WMC clauses).
        base_facts: Vec<HashSet<Vec<u8>>>,
        // Probability weight per base fact, pulled from the provenance store.
        base_probs: HashMap<Vec<u8>, f64>,
        // First row seen for this key; survives as the single output row.
        representative: (usize, usize),
        // All (batch, row) positions belonging to this group.
        row_locations: Vec<(usize, usize)>,
    }

    let mut group_accums: HashMap<Vec<ScalarKey>, GroupAccum> = HashMap::new();
    let mut non_shared_rows: Vec<(usize, usize)> = Vec::new();
    for (batch_idx, batch) in pre_fold_facts.iter().enumerate() {
        for row_idx in 0..batch.num_rows() {
            let key = extract_scalar_key(batch, key_indices, row_idx);
            if shared_keys.contains(&key) {
                let fact_hash = fact_hash_key(batch, &all_indices, row_idx);
                let bases = compute_lineage(&fact_hash, tracker, &mut HashSet::new());

                let accum = group_accums.entry(key).or_insert_with(|| GroupAccum {
                    base_facts: Vec::new(),
                    base_probs: HashMap::new(),
                    representative: (batch_idx, row_idx),
                    row_locations: Vec::new(),
                });

                // Record each base fact's probability once, read from its
                // provenance entry's probability column.
                for bf in &bases {
                    if !accum.base_probs.contains_key(bf)
                        && let Some(entry) = tracker.lookup(bf)
                        && let Some(val) = entry.fact_row.get(&prob_col_name)
                        && let Some(p) = value_to_f64(val)
                    {
                        accum.base_probs.insert(bf.clone(), p);
                    }
                }

                accum.base_facts.push(bases);
                accum.row_locations.push((batch_idx, row_idx));
            } else {
                non_shared_rows.push((batch_idx, row_idx));
            }
        }
    }

    // Rows to emit, and probability overrides for representative rows.
    let mut keep_rows: HashSet<(usize, usize)> = HashSet::new();
    let mut overrides: HashMap<(usize, usize), f64> = HashMap::new();

    // Non-shared rows pass through untouched.
    for &loc in &non_shared_rows {
        keep_rows.insert(loc);
    }

    for (key, accum) in &group_accums {
        let bdd_result = weighted_model_count(
            &accum.base_facts,
            &accum.base_probs,
            semiring_op,
            max_bdd_variables,
        );

        if bdd_result.approximated {
            // BDD blew the variable budget: warn once per (rule, key), record
            // the group as approximate, and keep all rows unchanged.
            if let Ok(mut warnings) = warnings_slot.write() {
                let key_desc = format!("{key:?}");
                let already_warned = warnings.iter().any(|w| {
                    w.code == RuntimeWarningCode::BddLimitExceeded
                        && w.rule_name == rule.name
                        && w.key_group.as_deref() == Some(&key_desc)
                });
                if !already_warned {
                    warnings.push(RuntimeWarning {
                        code: RuntimeWarningCode::BddLimitExceeded,
                        message: format!(
                            "Rule '{}': BDD variable limit exceeded ({} > {}). \
                            Falling back to independence-mode result.",
                            rule.name, bdd_result.variable_count, max_bdd_variables
                        ),
                        rule_name: rule.name.clone(),
                        variable_count: Some(bdd_result.variable_count),
                        key_group: Some(key_desc),
                    });
                }
            }
            if let Ok(mut approx) = approximate_slot.write() {
                let key_desc = format!("{key:?}");
                approx.entry(rule.name.clone()).or_default().push(key_desc);
            }
            for &loc in &accum.row_locations {
                keep_rows.insert(loc);
            }
        } else {
            // Exact result: collapse the group to its representative row and
            // override that row's probability.
            keep_rows.insert(accum.representative);
            overrides.insert(accum.representative, bdd_result.probability);
        }
    }

    // Materialize the surviving rows batch by batch, applying overrides.
    let mut result_batches = Vec::new();
    for (batch_idx, batch) in pre_fold_facts.iter().enumerate() {
        let kept_indices: Vec<usize> = (0..batch.num_rows())
            .filter(|&row_idx| keep_rows.contains(&(batch_idx, row_idx)))
            .collect();

        if kept_indices.is_empty() {
            continue;
        }

        let indices = arrow::array::UInt32Array::from(
            kept_indices.iter().map(|&i| i as u32).collect::<Vec<_>>(),
        );
        let mut columns: Vec<arrow::array::ArrayRef> = batch
            .columns()
            .iter()
            .map(|col| arrow::compute::take(col, &indices, None))
            .collect::<Result<Vec<_>, _>>()
            .map_err(arrow_err)?;

        // Per kept row: the exact probability, or None to keep the old value.
        let override_map: Vec<Option<f64>> = kept_indices
            .iter()
            .map(|&row_idx| overrides.get(&(batch_idx, row_idx)).copied())
            .collect();

        if override_map.iter().any(|o| o.is_some()) && prob_col_idx < columns.len() {
            let existing_prob = columns[prob_col_idx]
                .as_any()
                .downcast_ref::<arrow::array::Float64Array>();
            let new_values: Vec<f64> = override_map
                .iter()
                .enumerate()
                .map(|(i, ov)| match ov {
                    Some(p) => *p,
                    None => existing_prob.map(|arr| arr.value(i)).unwrap_or(0.0),
                })
                .collect();
            columns[prob_col_idx] = Arc::new(arrow::array::Float64Array::from(new_values));
        }

        let result_batch = RecordBatch::try_new(batch.schema(), columns).map_err(arrow_err)?;
        result_batches.push(result_batch);
    }

    Ok(result_batches)
}
1944
1945fn value_to_f64(val: &uni_common::Value) -> Option<f64> {
1947 match val {
1948 uni_common::Value::Float(f) => Some(*f),
1949 uni_common::Value::Int(i) => Some(*i as f64),
1950 _ => None,
1951 }
1952}
1953
1954fn compute_lineage(
1961 fact_hash: &[u8],
1962 tracker: &Arc<ProvenanceStore>,
1963 visited: &mut HashSet<Vec<u8>>,
1964) -> HashSet<Vec<u8>> {
1965 if !visited.insert(fact_hash.to_vec()) {
1966 return HashSet::new(); }
1968
1969 match tracker.lookup(fact_hash) {
1970 Some(entry) if !entry.support.is_empty() => {
1971 let mut bases = HashSet::new();
1972 for input in &entry.support {
1973 let child_bases = compute_lineage(&input.base_fact_id, tracker, visited);
1974 bases.extend(child_bases);
1975 }
1976 bases
1977 }
1978 _ => {
1979 let mut set = HashSet::new();
1981 set.insert(fact_hash.to_vec());
1982 set
1983 }
1984 }
1985}
1986
1987fn find_clause_for_row(
1992 delta_batch: &RecordBatch,
1993 row_idx: usize,
1994 all_indices: &[usize],
1995 clause_candidates: &[Vec<RecordBatch>],
1996) -> usize {
1997 let target_key = extract_scalar_key(delta_batch, all_indices, row_idx);
1998 for (clause_idx, batches) in clause_candidates.iter().enumerate() {
1999 for batch in batches {
2000 if batch.num_columns() != all_indices.len() {
2001 continue;
2002 }
2003 for r in 0..batch.num_rows() {
2004 if extract_scalar_key(batch, all_indices, r) == target_key {
2005 return clause_idx;
2006 }
2007 }
2008 }
2009 }
2010 0
2011}
2012
2013fn batch_row_to_value_map(
2015 batch: &RecordBatch,
2016 row_idx: usize,
2017) -> std::collections::HashMap<String, Value> {
2018 use uni_store::storage::arrow_convert::arrow_to_value;
2019
2020 let schema = batch.schema();
2021 schema
2022 .fields()
2023 .iter()
2024 .enumerate()
2025 .map(|(col_idx, field)| {
2026 let col = batch.column(col_idx);
2027 let val = arrow_to_value(col.as_ref(), row_idx, None);
2028 (field.name().clone(), val)
2029 })
2030 .collect()
2031}
2032
2033pub fn apply_anti_join(
2038 batches: Vec<RecordBatch>,
2039 neg_facts: &[RecordBatch],
2040 left_col: &str,
2041 right_col: &str,
2042) -> datafusion::error::Result<Vec<RecordBatch>> {
2043 use arrow::compute::filter_record_batch;
2044 use arrow_array::{Array as _, BooleanArray, UInt64Array};
2045
2046 let mut banned: std::collections::HashSet<u64> = std::collections::HashSet::new();
2048 for batch in neg_facts {
2049 let Ok(idx) = batch.schema().index_of(right_col) else {
2050 continue;
2051 };
2052 let arr = batch.column(idx);
2053 let Some(vids) = arr.as_any().downcast_ref::<UInt64Array>() else {
2054 continue;
2055 };
2056 for i in 0..vids.len() {
2057 if !vids.is_null(i) {
2058 banned.insert(vids.value(i));
2059 }
2060 }
2061 }
2062
2063 if banned.is_empty() {
2064 return Ok(batches);
2065 }
2066
2067 let mut result = Vec::new();
2069 for batch in batches {
2070 let Ok(idx) = batch.schema().index_of(left_col) else {
2071 result.push(batch);
2072 continue;
2073 };
2074 let arr = batch.column(idx);
2075 let Some(vids) = arr.as_any().downcast_ref::<UInt64Array>() else {
2076 result.push(batch);
2077 continue;
2078 };
2079 let keep: Vec<bool> = (0..vids.len())
2080 .map(|i| vids.is_null(i) || !banned.contains(&vids.value(i)))
2081 .collect();
2082 let keep_arr = BooleanArray::from(keep);
2083 let filtered = filter_record_batch(&batch, &keep_arr).map_err(arrow_err)?;
2084 if filtered.num_rows() > 0 {
2085 result.push(filtered);
2086 }
2087 }
2088 Ok(result)
2089}
2090
/// Appends a complement-probability column for probabilistic negation over a
/// single u64 join column.
///
/// Negated facts sharing a join value are combined with a noisy-or
/// (`1 - Π(1 - p_i)`); a null probability counts as 0. Each input row then
/// gets `complement_col_name = 1 - P(negated match)` appended — 1.0 for null
/// ids and ids absent from `neg_facts`. Batches whose join column is missing
/// or not u64 pass through unchanged (without the new column).
pub fn apply_prob_complement(
    batches: Vec<RecordBatch>,
    neg_facts: &[RecordBatch],
    left_col: &str,
    right_col: &str,
    prob_col: &str,
    complement_col_name: &str,
) -> datafusion::error::Result<Vec<RecordBatch>> {
    use arrow_array::{Array as _, Float64Array, UInt64Array};

    // Build id → combined negation probability from the negated facts.
    let mut prob_map: std::collections::HashMap<u64, f64> = std::collections::HashMap::new();
    for batch in neg_facts {
        let Ok(vid_idx) = batch.schema().index_of(right_col) else {
            continue;
        };
        let Ok(prob_idx) = batch.schema().index_of(prob_col) else {
            continue;
        };
        let Some(vids) = batch.column(vid_idx).as_any().downcast_ref::<UInt64Array>() else {
            continue;
        };
        let prob_arr = batch.column(prob_idx);
        let probs = prob_arr.as_any().downcast_ref::<Float64Array>();
        for i in 0..vids.len() {
            if !vids.is_null(i) {
                // Null or non-Float64 probability defaults to 0.
                let p = probs
                    .and_then(|arr| {
                        if arr.is_null(i) {
                            None
                        } else {
                            Some(arr.value(i))
                        }
                    })
                    .unwrap_or(0.0);
                // Combine duplicates with a noisy-or.
                prob_map
                    .entry(vids.value(i))
                    .and_modify(|existing| {
                        *existing = 1.0 - (1.0 - *existing) * (1.0 - p);
                    })
                    .or_insert(p);
            }
        }
    }

    let mut result = Vec::new();
    for batch in batches {
        let Ok(idx) = batch.schema().index_of(left_col) else {
            result.push(batch);
            continue;
        };
        let Some(vids) = batch.column(idx).as_any().downcast_ref::<UInt64Array>() else {
            result.push(batch);
            continue;
        };

        // 1 - P(negated) per row; null/unmatched ids get complement 1.0.
        let complements: Vec<f64> = (0..vids.len())
            .map(|i| {
                if vids.is_null(i) {
                    1.0
                } else {
                    let p = prob_map.get(&vids.value(i)).copied().unwrap_or(0.0);
                    1.0 - p
                }
            })
            .collect();

        let complement_arr = Float64Array::from(complements);

        // Append the complement column and extend the schema accordingly.
        let mut columns: Vec<arrow_array::ArrayRef> = batch.columns().to_vec();
        columns.push(std::sync::Arc::new(complement_arr));

        let mut fields: Vec<std::sync::Arc<arrow_schema::Field>> =
            batch.schema().fields().iter().cloned().collect();
        fields.push(std::sync::Arc::new(arrow_schema::Field::new(
            complement_col_name,
            arrow_schema::DataType::Float64,
            true,
        )));

        let new_schema = std::sync::Arc::new(arrow_schema::Schema::new(fields));
        let new_batch = RecordBatch::try_new(new_schema, columns).map_err(arrow_err)?;
        result.push(new_batch);
    }
    Ok(result)
}
2190
/// Composite-key variant of `apply_prob_complement`: appends a
/// complement-probability column where the negation joins on multiple u64
/// columns at once.
///
/// Negated facts sharing the same composite key are combined with a noisy-or
/// (`1 - Π(1 - p_i)`); null probabilities count as 0. Each input row gets
/// `complement_col_name = 1 - P(negated match)` appended — 1.0 when any key
/// component is null or the key is absent. Batches missing a join column, or
/// whose join columns are not all u64, pass through unchanged.
pub fn apply_prob_complement_composite(
    batches: Vec<RecordBatch>,
    neg_facts: &[RecordBatch],
    join_cols: &[(String, String)],
    prob_col: &str,
    complement_col_name: &str,
) -> datafusion::error::Result<Vec<RecordBatch>> {
    use arrow_array::{Array as _, Float64Array, UInt64Array};

    // Build composite key → combined negation probability.
    let mut prob_map: HashMap<Vec<u64>, f64> = HashMap::new();
    for batch in neg_facts {
        let right_indices: Vec<usize> = join_cols
            .iter()
            .filter_map(|(_, rc)| batch.schema().index_of(rc).ok())
            .collect();
        // All right-side join columns must be present.
        if right_indices.len() != join_cols.len() {
            continue;
        }
        let Ok(prob_idx) = batch.schema().index_of(prob_col) else {
            continue;
        };
        let prob_arr = batch.column(prob_idx);
        let probs = prob_arr.as_any().downcast_ref::<Float64Array>();
        for row in 0..batch.num_rows() {
            // Assemble the composite key; a null or non-u64 component
            // invalidates the row.
            let mut key = Vec::with_capacity(right_indices.len());
            let mut valid = true;
            for &ci in &right_indices {
                let col = batch.column(ci);
                if let Some(vids) = col.as_any().downcast_ref::<UInt64Array>() {
                    if vids.is_null(row) {
                        valid = false;
                        break;
                    }
                    key.push(vids.value(row));
                } else {
                    valid = false;
                    break;
                }
            }
            if !valid {
                continue;
            }
            // Null or non-Float64 probability defaults to 0.
            let p = probs
                .and_then(|arr| {
                    if arr.is_null(row) {
                        None
                    } else {
                        Some(arr.value(row))
                    }
                })
                .unwrap_or(0.0);
            // Combine duplicate keys with a noisy-or.
            prob_map
                .entry(key)
                .and_modify(|existing| {
                    *existing = 1.0 - (1.0 - *existing) * (1.0 - p);
                })
                .or_insert(p);
        }
    }

    let mut result = Vec::new();
    for batch in batches {
        let left_indices: Vec<usize> = join_cols
            .iter()
            .filter_map(|(lc, _)| batch.schema().index_of(lc).ok())
            .collect();
        if left_indices.len() != join_cols.len() {
            result.push(batch);
            continue;
        }
        // Every left-side join column must be u64 to evaluate the key.
        let all_u64 = left_indices.iter().all(|&ci| {
            batch
                .column(ci)
                .as_any()
                .downcast_ref::<UInt64Array>()
                .is_some()
        });
        if !all_u64 {
            result.push(batch);
            continue;
        }

        // 1 - P(negated) per row; null components short-circuit to 1.0.
        let complements: Vec<f64> = (0..batch.num_rows())
            .map(|row| {
                let mut key = Vec::with_capacity(left_indices.len());
                for &ci in &left_indices {
                    let vids = batch
                        .column(ci)
                        .as_any()
                        .downcast_ref::<UInt64Array>()
                        .unwrap();
                    if vids.is_null(row) {
                        return 1.0;
                    }
                    key.push(vids.value(row));
                }
                let p = prob_map.get(&key).copied().unwrap_or(0.0);
                1.0 - p
            })
            .collect();

        // Append the complement column and extend the schema accordingly.
        let complement_arr = Float64Array::from(complements);
        let mut columns: Vec<arrow_array::ArrayRef> = batch.columns().to_vec();
        columns.push(Arc::new(complement_arr));

        let mut fields: Vec<Arc<arrow_schema::Field>> =
            batch.schema().fields().iter().cloned().collect();
        fields.push(Arc::new(arrow_schema::Field::new(
            complement_col_name,
            arrow_schema::DataType::Float64,
            true,
        )));

        let new_schema = Arc::new(arrow_schema::Schema::new(fields));
        let new_batch = RecordBatch::try_new(new_schema, columns).map_err(arrow_err)?;
        result.push(new_batch);
    }
    Ok(result)
}
2319
/// Composite-key variant of `apply_anti_join`: filters out rows whose tuple
/// of u64 join-column values appears in `neg_facts`.
///
/// Rows with any null key component are kept; batches missing a join column,
/// or whose join columns are not all u64, pass through unchanged; batches
/// filtered down to zero rows are dropped.
pub fn apply_anti_join_composite(
    batches: Vec<RecordBatch>,
    neg_facts: &[RecordBatch],
    join_cols: &[(String, String)],
) -> datafusion::error::Result<Vec<RecordBatch>> {
    use arrow::compute::filter_record_batch;
    use arrow_array::{Array as _, BooleanArray, UInt64Array};

    // Collect banned composite keys from the negated facts.
    let mut banned: HashSet<Vec<u64>> = HashSet::new();
    for batch in neg_facts {
        let right_indices: Vec<usize> = join_cols
            .iter()
            .filter_map(|(_, rc)| batch.schema().index_of(rc).ok())
            .collect();
        // All right-side join columns must be present.
        if right_indices.len() != join_cols.len() {
            continue;
        }
        for row in 0..batch.num_rows() {
            // Assemble the composite key; a null or non-u64 component
            // invalidates the row.
            let mut key = Vec::with_capacity(right_indices.len());
            let mut valid = true;
            for &ci in &right_indices {
                let col = batch.column(ci);
                if let Some(vids) = col.as_any().downcast_ref::<UInt64Array>() {
                    if vids.is_null(row) {
                        valid = false;
                        break;
                    }
                    key.push(vids.value(row));
                } else {
                    valid = false;
                    break;
                }
            }
            if valid {
                banned.insert(key);
            }
        }
    }

    if banned.is_empty() {
        return Ok(batches);
    }

    let mut result = Vec::new();
    for batch in batches {
        let left_indices: Vec<usize> = join_cols
            .iter()
            .filter_map(|(lc, _)| batch.schema().index_of(lc).ok())
            .collect();
        if left_indices.len() != join_cols.len() {
            result.push(batch);
            continue;
        }
        // Every left-side join column must be u64 to evaluate the key.
        let all_u64 = left_indices.iter().all(|&ci| {
            batch
                .column(ci)
                .as_any()
                .downcast_ref::<UInt64Array>()
                .is_some()
        });
        if !all_u64 {
            result.push(batch);
            continue;
        }

        // Keep a row when its key has a null component or is not banned.
        let keep: Vec<bool> = (0..batch.num_rows())
            .map(|row| {
                let mut key = Vec::with_capacity(left_indices.len());
                for &ci in &left_indices {
                    let vids = batch
                        .column(ci)
                        .as_any()
                        .downcast_ref::<UInt64Array>()
                        .unwrap();
                    if vids.is_null(row) {
                        return true;
                    }
                    key.push(vids.value(row));
                }
                !banned.contains(&key)
            })
            .collect();
        let keep_arr = BooleanArray::from(keep);
        let filtered = filter_record_batch(&batch, &keep_arr).map_err(arrow_err)?;
        if filtered.num_rows() > 0 {
            result.push(filtered);
        }
    }
    Ok(result)
}
2417
2418pub fn multiply_prob_factors(
2429 batches: Vec<RecordBatch>,
2430 prob_col: Option<&str>,
2431 complement_cols: &[String],
2432) -> datafusion::error::Result<Vec<RecordBatch>> {
2433 use arrow_array::{Array as _, Float64Array};
2434
2435 let mut result = Vec::with_capacity(batches.len());
2436
2437 for batch in batches {
2438 if batch.num_rows() == 0 {
2439 let keep: Vec<usize> = batch
2441 .schema()
2442 .fields()
2443 .iter()
2444 .enumerate()
2445 .filter(|(_, f)| !complement_cols.contains(f.name()))
2446 .map(|(i, _)| i)
2447 .collect();
2448 let fields: Vec<_> = keep
2449 .iter()
2450 .map(|&i| batch.schema().field(i).clone())
2451 .collect();
2452 let cols: Vec<_> = keep.iter().map(|&i| batch.column(i).clone()).collect();
2453 let schema = std::sync::Arc::new(arrow_schema::Schema::new(fields));
2454 result.push(
2455 RecordBatch::try_new(schema, cols).map_err(|e| {
2456 datafusion::error::DataFusionError::ArrowError(Box::new(e), None)
2457 })?,
2458 );
2459 continue;
2460 }
2461
2462 let num_rows = batch.num_rows();
2463
2464 let mut combined = vec![1.0f64; num_rows];
2466 for col_name in complement_cols {
2467 if let Ok(idx) = batch.schema().index_of(col_name) {
2468 let arr = batch
2469 .column(idx)
2470 .as_any()
2471 .downcast_ref::<Float64Array>()
2472 .ok_or_else(|| {
2473 datafusion::error::DataFusionError::Internal(format!(
2474 "Expected Float64 for complement column {col_name}"
2475 ))
2476 })?;
2477 for (i, val) in combined.iter_mut().enumerate().take(num_rows) {
2478 if !arr.is_null(i) {
2479 *val *= arr.value(i);
2480 }
2481 }
2482 }
2483 }
2484
2485 let final_prob: Vec<f64> = if let Some(prob_name) = prob_col {
2487 if let Ok(idx) = batch.schema().index_of(prob_name) {
2488 let arr = batch
2489 .column(idx)
2490 .as_any()
2491 .downcast_ref::<Float64Array>()
2492 .ok_or_else(|| {
2493 datafusion::error::DataFusionError::Internal(format!(
2494 "Expected Float64 for PROB column {prob_name}"
2495 ))
2496 })?;
2497 (0..num_rows)
2498 .map(|i| {
2499 if arr.is_null(i) {
2500 combined[i]
2501 } else {
2502 arr.value(i) * combined[i]
2503 }
2504 })
2505 .collect()
2506 } else {
2507 combined
2508 }
2509 } else {
2510 combined
2511 };
2512
2513 let new_prob_array: arrow_array::ArrayRef =
2514 std::sync::Arc::new(Float64Array::from(final_prob));
2515
2516 let mut fields = Vec::new();
2518 let mut columns = Vec::new();
2519
2520 for (idx, field) in batch.schema().fields().iter().enumerate() {
2521 if complement_cols.contains(field.name()) {
2522 continue;
2523 }
2524 if prob_col.is_some_and(|p| field.name() == p) {
2525 fields.push(field.clone());
2526 columns.push(new_prob_array.clone());
2527 } else {
2528 fields.push(field.clone());
2529 columns.push(batch.column(idx).clone());
2530 }
2531 }
2532
2533 let schema = std::sync::Arc::new(arrow_schema::Schema::new(fields));
2534 result.push(RecordBatch::try_new(schema, columns).map_err(arrow_err)?);
2535 }
2536
2537 Ok(result)
2538}
2539
2540fn update_derived_scan_handles(
2545 registry: &DerivedScanRegistry,
2546 states: &[FixpointState],
2547 current_rule_idx: usize,
2548 rules: &[FixpointRulePlan],
2549) {
2550 let current_rule_name = &rules[current_rule_idx].name;
2551
2552 for entry in ®istry.entries {
2553 let source_state_idx = rules.iter().position(|r| r.name == entry.rule_name);
2555 let Some(source_idx) = source_state_idx else {
2556 continue;
2557 };
2558
2559 let is_self = entry.rule_name == *current_rule_name;
2560 let data = if is_self {
2561 states[source_idx].all_delta().to_vec()
2563 } else {
2564 states[source_idx].all_facts().to_vec()
2566 };
2567
2568 let data = if data.is_empty() || data.iter().all(|b| b.num_rows() == 0) {
2570 vec![RecordBatch::new_empty(Arc::clone(&entry.schema))]
2571 } else {
2572 data
2573 };
2574
2575 let mut guard = entry.data.write();
2576 *guard = data;
2577 }
2578}
2579
pub struct DerivedScanExec {
    // Shared, refreshable buffer of batches; rewritten between fixpoint
    // iterations by update_derived_scan_handles.
    data: Arc<RwLock<Vec<RecordBatch>>>,
    // Output schema; also used to synthesize empty placeholder batches.
    schema: SchemaRef,
    // Cached DataFusion plan properties for this leaf node.
    properties: PlanProperties,
}
2594
2595impl DerivedScanExec {
2596 pub fn new(data: Arc<RwLock<Vec<RecordBatch>>>, schema: SchemaRef) -> Self {
2597 let properties = compute_plan_properties(Arc::clone(&schema));
2598 Self {
2599 data,
2600 schema,
2601 properties,
2602 }
2603 }
2604}
2605
2606impl fmt::Debug for DerivedScanExec {
2607 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
2608 f.debug_struct("DerivedScanExec")
2609 .field("schema", &self.schema)
2610 .finish()
2611 }
2612}
2613
2614impl DisplayAs for DerivedScanExec {
2615 fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter<'_>) -> fmt::Result {
2616 write!(f, "DerivedScanExec")
2617 }
2618}
2619
2620impl ExecutionPlan for DerivedScanExec {
2621 fn name(&self) -> &str {
2622 "DerivedScanExec"
2623 }
2624 fn as_any(&self) -> &dyn Any {
2625 self
2626 }
2627 fn schema(&self) -> SchemaRef {
2628 Arc::clone(&self.schema)
2629 }
2630 fn properties(&self) -> &PlanProperties {
2631 &self.properties
2632 }
2633 fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
2634 vec![]
2635 }
2636 fn with_new_children(
2637 self: Arc<Self>,
2638 _children: Vec<Arc<dyn ExecutionPlan>>,
2639 ) -> DFResult<Arc<dyn ExecutionPlan>> {
2640 Ok(self)
2641 }
2642 fn execute(
2643 &self,
2644 _partition: usize,
2645 _context: Arc<TaskContext>,
2646 ) -> DFResult<SendableRecordBatchStream> {
2647 let batches = {
2648 let guard = self.data.read();
2649 if guard.is_empty() {
2650 vec![RecordBatch::new_empty(Arc::clone(&self.schema))]
2651 } else {
2652 guard.clone()
2653 }
2654 };
2655 Ok(Box::pin(MemoryStream::try_new(
2656 batches,
2657 Arc::clone(&self.schema),
2658 None,
2659 )?))
2660 }
2661}
2662
struct InMemoryExec {
    // Fixed set of batches served verbatim on every execute() call.
    batches: Vec<RecordBatch>,
    // Output schema of the stored batches.
    schema: SchemaRef,
    // Cached DataFusion plan properties for this leaf node.
    properties: PlanProperties,
}
2676
2677impl InMemoryExec {
2678 fn new(batches: Vec<RecordBatch>, schema: SchemaRef) -> Self {
2679 let properties = compute_plan_properties(Arc::clone(&schema));
2680 Self {
2681 batches,
2682 schema,
2683 properties,
2684 }
2685 }
2686}
2687
// Manual Debug impl: shows only the batch count and schema, not the batch
// contents themselves.
impl fmt::Debug for InMemoryExec {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("InMemoryExec")
            .field("num_batches", &self.batches.len())
            .field("schema", &self.schema)
            .finish()
    }
}
2696
2697impl DisplayAs for InMemoryExec {
2698 fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter<'_>) -> fmt::Result {
2699 write!(f, "InMemoryExec: batches={}", self.batches.len())
2700 }
2701}
2702
2703impl ExecutionPlan for InMemoryExec {
2704 fn name(&self) -> &str {
2705 "InMemoryExec"
2706 }
2707 fn as_any(&self) -> &dyn Any {
2708 self
2709 }
2710 fn schema(&self) -> SchemaRef {
2711 Arc::clone(&self.schema)
2712 }
2713 fn properties(&self) -> &PlanProperties {
2714 &self.properties
2715 }
2716 fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
2717 vec![]
2718 }
2719 fn with_new_children(
2720 self: Arc<Self>,
2721 _children: Vec<Arc<dyn ExecutionPlan>>,
2722 ) -> DFResult<Arc<dyn ExecutionPlan>> {
2723 Ok(self)
2724 }
2725 fn execute(
2726 &self,
2727 _partition: usize,
2728 _context: Arc<TaskContext>,
2729 ) -> DFResult<SendableRecordBatchStream> {
2730 Ok(Box::pin(MemoryStream::try_new(
2731 self.batches.clone(),
2732 Arc::clone(&self.schema),
2733 None,
2734 )?))
2735 }
2736}
2737
/// Applies a rule's HAVING predicates to already-aggregated batches.
///
/// Each Cypher expression is converted to a DataFusion expression, run
/// through type coercion, compiled to a physical expression, and evaluated
/// per batch; the boolean masks of all predicates are AND-ed (SQL HAVING
/// semantics). Batches whose filtered result is empty are dropped.
fn apply_having_filter(
    batches: Vec<RecordBatch>,
    having_exprs: &[Expr],
    schema: &SchemaRef,
) -> DFResult<Vec<RecordBatch>> {
    use arrow::compute::{and, filter_record_batch};
    use arrow_array::BooleanArray;
    use datafusion::common::DFSchema;
    use datafusion::logical_expr::LogicalPlanBuilder;
    use datafusion::optimizer::AnalyzerRule;
    use datafusion::optimizer::analyzer::type_coercion::TypeCoercion;
    use datafusion::physical_expr::create_physical_expr;
    use datafusion::prelude::SessionContext;

    if batches.is_empty() {
        return Ok(batches);
    }

    // Expression planning/coercion needs a DFSchema view of the Arrow schema.
    let df_schema = DFSchema::try_from(schema.as_ref().clone()).map_err(|e| {
        datafusion::common::DataFusionError::Internal(format!("HAVING schema conversion: {e}"))
    })?;

    // Throwaway session: used only as a source of config options and
    // execution props for coercion and physical-expr creation.
    let ctx = SessionContext::new();
    let state = ctx.state();
    let config = state.config_options().clone();
    let props = state.execution_props();

    let physical_exprs: Vec<Arc<dyn datafusion::physical_expr::PhysicalExpr>> = having_exprs
        .iter()
        .map(|expr| {
            // Translate the Cypher AST expression into a DataFusion Expr.
            let df_expr = crate::query::df_expr::cypher_expr_to_df(expr, None).map_err(|e| {
                datafusion::common::DataFusionError::Internal(format!(
                    "HAVING expression conversion: {e}"
                ))
            })?;

            // Build a dummy Filter over an empty relation purely so the
            // TypeCoercion analyzer can rewrite the predicate against
            // `df_schema`.
            let empty = datafusion::logical_expr::LogicalPlan::EmptyRelation(
                datafusion::logical_expr::EmptyRelation {
                    produce_one_row: false,
                    schema: Arc::new(df_schema.clone()),
                },
            );
            let filter_plan = LogicalPlanBuilder::from(empty)
                .filter(df_expr.clone())?
                .build()?;
            // Best-effort coercion: if analysis fails or yields an
            // unexpected plan shape, fall back to the uncoerced expression
            // (presumably intentional — confirm before tightening).
            let coerced_expr = match TypeCoercion::new().analyze(filter_plan, &config) {
                Ok(datafusion::logical_expr::LogicalPlan::Filter(f)) => f.predicate,
                _ => df_expr,
            };

            create_physical_expr(&coerced_expr, &df_schema, props)
        })
        .collect::<DFResult<Vec<_>>>()?;

    let mut result = Vec::new();
    for batch in batches {
        // AND together the boolean masks of every predicate for this batch.
        let mut mask: Option<BooleanArray> = None;
        for phys_expr in &physical_exprs {
            let value = phys_expr.evaluate(&batch)?;
            let arr = value.into_array(batch.num_rows())?;
            let bool_arr = arr.as_any().downcast_ref::<BooleanArray>().ok_or_else(|| {
                datafusion::common::DataFusionError::Internal(
                    "HAVING condition must evaluate to boolean".into(),
                )
            })?;
            mask = Some(match mask {
                None => bool_arr.clone(),
                Some(prev) => and(&prev, bool_arr).map_err(arrow_err)?,
            });
        }
        if let Some(ref m) = mask {
            let filtered = filter_record_batch(&batch, m).map_err(arrow_err)?;
            // Drop batches filtered down to zero rows.
            if filtered.num_rows() > 0 {
                result.push(filtered);
            }
        } else {
            // `mask` is None only when `having_exprs` is empty: pass through.
            result.push(batch);
        }
    }
    Ok(result)
}
2838
/// Runs the per-rule post-processing chain over a rule's final fact set
/// after the fixpoint loop converges: PRIORITY -> FOLD -> HAVING -> BEST BY,
/// each stage applied only when the rule declares it. Returns the fully
/// materialized result batches.
pub(crate) async fn apply_post_fixpoint_chain(
    facts: Vec<RecordBatch>,
    rule: &FixpointRulePlan,
    task_ctx: &Arc<TaskContext>,
    strict_probability_domain: bool,
    probability_epsilon: f64,
) -> DFResult<Vec<RecordBatch>> {
    // Fast path: plain rules need no post-processing.
    if !rule.has_fold && !rule.has_best_by && !rule.has_priority && rule.having.is_empty() {
        return Ok(facts);
    }

    // Prefer the schema of an actual non-empty batch; fall back to the
    // rule's declared yield schema when all batches are empty.
    let schema = facts
        .iter()
        .find(|b| b.num_rows() > 0)
        .map(|b| b.schema())
        .unwrap_or_else(|| Arc::clone(&rule.yield_schema));
    let input: Arc<dyn ExecutionPlan> = Arc::new(InMemoryExec::new(facts, schema));

    // PRIORITY stage: requires the synthetic `__priority` column produced
    // upstream; its absence is an internal planning error.
    let current: Arc<dyn ExecutionPlan> = if rule.has_priority {
        let priority_schema = input.schema();
        let priority_idx = priority_schema.index_of("__priority").map_err(|_| {
            datafusion::common::DataFusionError::Internal(
                "PRIORITY rule missing __priority column".to_string(),
            )
        })?;
        Arc::new(PriorityExec::new(
            input,
            rule.key_column_indices.clone(),
            priority_idx,
        ))
    } else {
        input
    };

    // FOLD stage: aggregate per key group using the rule's fold bindings.
    let current: Arc<dyn ExecutionPlan> = if rule.has_fold && !rule.fold_bindings.is_empty() {
        Arc::new(FoldExec::new(
            current,
            rule.key_column_indices.clone(),
            rule.fold_bindings.clone(),
            strict_probability_domain,
            probability_epsilon,
        ))
    } else {
        current
    };

    // HAVING stage: evaluated on materialized batches, so collect here.
    // An empty filtered result short-circuits the rest of the chain.
    let current: Arc<dyn ExecutionPlan> = if !rule.having.is_empty() {
        let batches = collect_all_partitions(&current, Arc::clone(task_ctx)).await?;
        let filtered = apply_having_filter(batches, &rule.having, &current.schema())?;
        if filtered.is_empty() {
            return Ok(filtered);
        }
        Arc::new(InMemoryExec::new(filtered, Arc::clone(&current.schema())))
    } else {
        current
    };

    // BEST BY stage: keep the best row per key group per the criteria.
    let current: Arc<dyn ExecutionPlan> = if rule.has_best_by && !rule.best_by_criteria.is_empty() {
        Arc::new(BestByExec::new(
            current,
            rule.key_column_indices.clone(),
            rule.best_by_criteria.clone(),
            rule.deterministic,
        ))
    } else {
        current
    };

    collect_all_partitions(&current, Arc::clone(task_ctx)).await
}
2920
/// Physical operator that evaluates a set of recursive (fixpoint) rules to
/// convergence and streams the final derived facts.
pub struct FixpointExec {
    /// One plan per rule participating in the fixpoint.
    rules: Vec<FixpointRulePlan>,
    /// Hard cap on fixpoint iterations.
    max_iterations: usize,
    /// Wall-clock budget for the whole loop.
    timeout: Duration,
    graph_ctx: Arc<GraphExecutionContext>,
    session_ctx: Arc<RwLock<datafusion::prelude::SessionContext>>,
    storage: Arc<StorageManager>,
    schema_info: Arc<UniSchema>,
    /// Query parameters forwarded to rule-body execution.
    params: HashMap<String, Value>,
    /// Shared buffers re-injecting derived facts into rule bodies.
    derived_scan_registry: Arc<DerivedScanRegistry>,
    output_schema: SchemaRef,
    properties: PlanProperties,
    metrics: ExecutionPlanMetricsSet,
    /// Byte budget for derived facts.
    max_derived_bytes: usize,
    /// Optional provenance/proof recording.
    derivation_tracker: Option<Arc<ProvenanceStore>>,
    /// Per-rule iteration counters, exposed via `iteration_counts()`.
    iteration_counts: Arc<StdRwLock<HashMap<String, usize>>>,
    strict_probability_domain: bool,
    probability_epsilon: f64,
    exact_probability: bool,
    max_bdd_variables: usize,
    /// Out-slot for runtime warnings produced during the loop.
    warnings_slot: Arc<StdRwLock<Vec<RuntimeWarning>>>,
    // Out-slot keyed by name with approximation reasons — presumably rule
    // name -> reasons; confirm against run_fixpoint_loop.
    approximate_slot: Arc<StdRwLock<HashMap<String, Vec<String>>>>,
    top_k_proofs: usize,
}
2958
// Manual Debug impl: prints only the scalar configuration; the heavy
// context fields (storage, session, registry, slots) are elided via
// `finish_non_exhaustive`.
impl fmt::Debug for FixpointExec {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("FixpointExec")
            .field("rules_count", &self.rules.len())
            .field("max_iterations", &self.max_iterations)
            .field("timeout", &self.timeout)
            .field("output_schema", &self.output_schema)
            .field("max_derived_bytes", &self.max_derived_bytes)
            .finish_non_exhaustive()
    }
}
2970
2971impl FixpointExec {
2972 #[expect(
2974 clippy::too_many_arguments,
2975 reason = "FixpointExec configuration needs all context"
2976 )]
2977 pub fn new(
2978 rules: Vec<FixpointRulePlan>,
2979 max_iterations: usize,
2980 timeout: Duration,
2981 graph_ctx: Arc<GraphExecutionContext>,
2982 session_ctx: Arc<RwLock<datafusion::prelude::SessionContext>>,
2983 storage: Arc<StorageManager>,
2984 schema_info: Arc<UniSchema>,
2985 params: HashMap<String, Value>,
2986 derived_scan_registry: Arc<DerivedScanRegistry>,
2987 output_schema: SchemaRef,
2988 max_derived_bytes: usize,
2989 derivation_tracker: Option<Arc<ProvenanceStore>>,
2990 iteration_counts: Arc<StdRwLock<HashMap<String, usize>>>,
2991 strict_probability_domain: bool,
2992 probability_epsilon: f64,
2993 exact_probability: bool,
2994 max_bdd_variables: usize,
2995 warnings_slot: Arc<StdRwLock<Vec<RuntimeWarning>>>,
2996 approximate_slot: Arc<StdRwLock<HashMap<String, Vec<String>>>>,
2997 top_k_proofs: usize,
2998 ) -> Self {
2999 let properties = compute_plan_properties(Arc::clone(&output_schema));
3000 Self {
3001 rules,
3002 max_iterations,
3003 timeout,
3004 graph_ctx,
3005 session_ctx,
3006 storage,
3007 schema_info,
3008 params,
3009 derived_scan_registry,
3010 output_schema,
3011 properties,
3012 metrics: ExecutionPlanMetricsSet::new(),
3013 max_derived_bytes,
3014 derivation_tracker,
3015 iteration_counts,
3016 strict_probability_domain,
3017 probability_epsilon,
3018 exact_probability,
3019 max_bdd_variables,
3020 warnings_slot,
3021 approximate_slot,
3022 top_k_proofs,
3023 }
3024 }
3025
3026 pub fn iteration_counts(&self) -> Arc<StdRwLock<HashMap<String, usize>>> {
3028 Arc::clone(&self.iteration_counts)
3029 }
3030}
3031
3032impl DisplayAs for FixpointExec {
3033 fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter<'_>) -> fmt::Result {
3034 write!(
3035 f,
3036 "FixpointExec: rules=[{}], max_iter={}, timeout={:?}",
3037 self.rules
3038 .iter()
3039 .map(|r| r.name.as_str())
3040 .collect::<Vec<_>>()
3041 .join(", "),
3042 self.max_iterations,
3043 self.timeout,
3044 )
3045 }
3046}
3047
impl ExecutionPlan for FixpointExec {
    fn name(&self) -> &str {
        "FixpointExec"
    }

    fn as_any(&self) -> &dyn Any {
        self
    }

    fn schema(&self) -> SchemaRef {
        Arc::clone(&self.output_schema)
    }

    fn properties(&self) -> &PlanProperties {
        &self.properties
    }

    // Leaf node: all inputs are resolved through the derived-scan registry,
    // not through plan children.
    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
        vec![]
    }

    fn with_new_children(
        self: Arc<Self>,
        children: Vec<Arc<dyn ExecutionPlan>>,
    ) -> DFResult<Arc<dyn ExecutionPlan>> {
        if !children.is_empty() {
            return Err(datafusion::error::DataFusionError::Plan(
                "FixpointExec has no children".to_string(),
            ));
        }
        Ok(self)
    }

    /// Kicks off the fixpoint loop lazily: everything the loop needs is
    /// cloned into a future here, and the returned stream drives that
    /// future on first poll, then emits the resulting batches.
    fn execute(
        &self,
        partition: usize,
        _context: Arc<TaskContext>,
    ) -> DFResult<SendableRecordBatchStream> {
        let metrics = BaselineMetrics::new(&self.metrics, partition);

        // Deep-copy the rule plans for the 'static future. Copied field by
        // field — presumably FixpointRulePlan doesn't derive Clone; confirm
        // before simplifying to `r.clone()`.
        let rules = self
            .rules
            .iter()
            .map(|r| {
                FixpointRulePlan {
                    name: r.name.clone(),
                    clauses: r
                        .clauses
                        .iter()
                        .map(|c| FixpointClausePlan {
                            body_logical: c.body_logical.clone(),
                            is_ref_bindings: c.is_ref_bindings.clone(),
                            priority: c.priority,
                            along_bindings: c.along_bindings.clone(),
                        })
                        .collect(),
                    yield_schema: Arc::clone(&r.yield_schema),
                    key_column_indices: r.key_column_indices.clone(),
                    priority: r.priority,
                    has_fold: r.has_fold,
                    fold_bindings: r.fold_bindings.clone(),
                    having: r.having.clone(),
                    has_best_by: r.has_best_by,
                    best_by_criteria: r.best_by_criteria.clone(),
                    has_priority: r.has_priority,
                    deterministic: r.deterministic,
                    prob_column_name: r.prob_column_name.clone(),
                }
            })
            .collect();

        // Clone the shared context (mostly Arc bumps) so the async block
        // owns everything and can outlive `self`.
        let max_iterations = self.max_iterations;
        let timeout = self.timeout;
        let graph_ctx = Arc::clone(&self.graph_ctx);
        let session_ctx = Arc::clone(&self.session_ctx);
        let storage = Arc::clone(&self.storage);
        let schema_info = Arc::clone(&self.schema_info);
        let params = self.params.clone();
        let registry = Arc::clone(&self.derived_scan_registry);
        let output_schema = Arc::clone(&self.output_schema);
        let max_derived_bytes = self.max_derived_bytes;
        let derivation_tracker = self.derivation_tracker.clone();
        let iteration_counts = Arc::clone(&self.iteration_counts);
        let strict_probability_domain = self.strict_probability_domain;
        let probability_epsilon = self.probability_epsilon;
        let exact_probability = self.exact_probability;
        let max_bdd_variables = self.max_bdd_variables;
        let warnings_slot = Arc::clone(&self.warnings_slot);
        let approximate_slot = Arc::clone(&self.approximate_slot);
        let top_k_proofs = self.top_k_proofs;

        // The actual semi-naive iteration lives in `run_fixpoint_loop`;
        // this future is not polled until the stream is.
        let fut = async move {
            run_fixpoint_loop(
                rules,
                max_iterations,
                timeout,
                graph_ctx,
                session_ctx,
                storage,
                schema_info,
                params,
                registry,
                output_schema,
                max_derived_bytes,
                derivation_tracker,
                iteration_counts,
                strict_probability_domain,
                probability_epsilon,
                exact_probability,
                max_bdd_variables,
                warnings_slot,
                approximate_slot,
                top_k_proofs,
            )
            .await
        };

        Ok(Box::pin(FixpointStream {
            state: FixpointStreamState::Running(Box::pin(fut)),
            schema: Arc::clone(&self.output_schema),
            metrics,
        }))
    }

    fn metrics(&self) -> Option<MetricsSet> {
        Some(self.metrics.clone_inner())
    }
}
3181
/// Phases of `FixpointStream`: drive the fixpoint future to completion,
/// then emit its result batches one at a time, then report end-of-stream.
enum FixpointStreamState {
    /// Fixpoint loop still running; holds the pinned future.
    Running(Pin<Box<dyn std::future::Future<Output = DFResult<Vec<RecordBatch>>> + Send>>),
    /// Result batches plus the index of the next batch to emit.
    Emitting(Vec<RecordBatch>, usize),
    /// Stream exhausted (or failed); always yields `None` from here on.
    Done,
}
3194
/// RecordBatch stream adapter that runs the fixpoint future on first poll
/// and then yields its batches (see `FixpointStreamState`).
struct FixpointStream {
    /// Current phase of the run-then-emit state machine.
    state: FixpointStreamState,
    /// Output schema reported via `RecordBatchStream::schema`.
    schema: SchemaRef,
    /// Baseline metrics; output row counts are recorded per emitted batch.
    metrics: BaselineMetrics,
}
3200
impl Stream for FixpointStream {
    type Item = DFResult<RecordBatch>;

    // State machine: Running -> Emitting -> Done. The loop lets a Running
    // future that completes fall straight through to emitting its first
    // batch within the same poll.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let this = self.get_mut();
        loop {
            match &mut this.state {
                FixpointStreamState::Running(fut) => match fut.as_mut().poll(cx) {
                    Poll::Ready(Ok(batches)) => {
                        // An empty result ends the stream immediately.
                        if batches.is_empty() {
                            this.state = FixpointStreamState::Done;
                            return Poll::Ready(None);
                        }
                        this.state = FixpointStreamState::Emitting(batches, 0);
                    }
                    // Errors terminate the stream after being surfaced once.
                    Poll::Ready(Err(e)) => {
                        this.state = FixpointStreamState::Done;
                        return Poll::Ready(Some(Err(e)));
                    }
                    Poll::Pending => return Poll::Pending,
                },
                FixpointStreamState::Emitting(batches, idx) => {
                    if *idx >= batches.len() {
                        this.state = FixpointStreamState::Done;
                        return Poll::Ready(None);
                    }
                    // RecordBatch clone is shallow (Arc-backed columns).
                    let batch = batches[*idx].clone();
                    *idx += 1;
                    this.metrics.record_output(batch.num_rows());
                    return Poll::Ready(Some(Ok(batch)));
                }
                FixpointStreamState::Done => return Poll::Ready(None),
            }
        }
    }
}
3238
// Exposes the fixpoint output schema to DataFusion stream consumers.
impl RecordBatchStream for FixpointStream {
    fn schema(&self) -> SchemaRef {
        Arc::clone(&self.schema)
    }
}
3244
3245#[cfg(test)]
3250mod tests {
3251 use super::*;
3252 use arrow_array::{Float64Array, Int64Array, StringArray};
3253 use arrow_schema::{DataType, Field, Schema};
3254
    /// Schema shared by most tests: (name: Utf8, value: Int64), both nullable.
    fn test_schema() -> SchemaRef {
        Arc::new(Schema::new(vec![
            Field::new("name", DataType::Utf8, true),
            Field::new("value", DataType::Int64, true),
        ]))
    }
3261
3262 fn make_batch(names: &[&str], values: &[i64]) -> RecordBatch {
3263 RecordBatch::try_new(
3264 test_schema(),
3265 vec![
3266 Arc::new(StringArray::from(
3267 names.iter().map(|s| Some(*s)).collect::<Vec<_>>(),
3268 )),
3269 Arc::new(Int64Array::from(values.to_vec())),
3270 ],
3271 )
3272 .unwrap()
3273 }
3274
    // Merging into an empty state keeps every row and mirrors it into the delta.
    #[tokio::test]
    async fn test_fixpoint_state_empty_facts_adds_all() {
        let schema = test_schema();
        let mut state = FixpointState::new("test".into(), schema, vec![0], 1_000_000, None, false);

        let batch = make_batch(&["a", "b", "c"], &[1, 2, 3]);
        let changed = state.merge_delta(vec![batch], None).await.unwrap();

        assert!(changed);
        assert_eq!(state.all_facts().len(), 1);
        assert_eq!(state.all_facts()[0].num_rows(), 3);
        assert_eq!(state.all_delta().len(), 1);
        assert_eq!(state.all_delta()[0].num_rows(), 3);
    }
3291
    // Re-merging identical rows reports "no change" and leaves an empty delta.
    #[tokio::test]
    async fn test_fixpoint_state_exact_duplicates_excluded() {
        let schema = test_schema();
        let mut state = FixpointState::new("test".into(), schema, vec![0], 1_000_000, None, false);

        let batch1 = make_batch(&["a", "b"], &[1, 2]);
        state.merge_delta(vec![batch1], None).await.unwrap();

        let batch2 = make_batch(&["a", "b"], &[1, 2]);
        let changed = state.merge_delta(vec![batch2], None).await.unwrap();
        assert!(!changed);
        assert!(
            state.all_delta().is_empty() || state.all_delta().iter().all(|b| b.num_rows() == 0)
        );
    }
3308
    // Only genuinely new rows land in the delta; totals accumulate across merges.
    #[tokio::test]
    async fn test_fixpoint_state_partial_overlap() {
        let schema = test_schema();
        let mut state = FixpointState::new("test".into(), schema, vec![0], 1_000_000, None, false);

        let batch1 = make_batch(&["a", "b"], &[1, 2]);
        state.merge_delta(vec![batch1], None).await.unwrap();

        // ("a", 1) repeats; only ("c", 3) is new.
        let batch2 = make_batch(&["a", "c"], &[1, 3]);
        let changed = state.merge_delta(vec![batch2], None).await.unwrap();
        assert!(changed);

        let delta_rows: usize = state.all_delta().iter().map(|b| b.num_rows()).sum();
        assert_eq!(delta_rows, 1);

        let total_rows: usize = state.all_facts().iter().map(|b| b.num_rows()).sum();
        assert_eq!(total_rows, 3);
    }
3330
    // A merge with an empty delta marks the state converged.
    #[tokio::test]
    async fn test_fixpoint_state_convergence() {
        let schema = test_schema();
        let mut state = FixpointState::new("test".into(), schema, vec![0], 1_000_000, None, false);

        let batch = make_batch(&["a"], &[1]);
        state.merge_delta(vec![batch], None).await.unwrap();

        let changed = state.merge_delta(vec![], None).await.unwrap();
        assert!(!changed);
        assert!(state.is_converged());
    }
3344
    // Dedup memory persists across compute_delta calls: rows already seen
    // earlier never re-enter later deltas.
    #[test]
    fn test_row_dedup_persistent_across_calls() {
        let schema = test_schema();
        let mut rd = RowDedupState::try_new(&schema).expect("schema should be supported");

        let batch1 = make_batch(&["a", "b"], &[1, 2]);
        let delta1 = rd.compute_delta(&[batch1], &schema).unwrap();
        let rows1: usize = delta1.iter().map(|b| b.num_rows()).sum();
        assert_eq!(rows1, 2);

        // Exact repeat of the first call: nothing new.
        let batch2 = make_batch(&["a", "b"], &[1, 2]);
        let delta2 = rd.compute_delta(&[batch2], &schema).unwrap();
        let rows2: usize = delta2.iter().map(|b| b.num_rows()).sum();
        assert_eq!(rows2, 0);

        // Only ("c", 3) is unseen.
        let batch3 = make_batch(&["a", "c"], &[1, 3]);
        let delta3 = rd.compute_delta(&[batch3], &schema).unwrap();
        let rows3: usize = delta3.iter().map(|b| b.num_rows()).sum();
        assert_eq!(rows3, 1);
    }
3372
    // NULLs compare equal to each other for dedup purposes, but rows differing
    // in a non-null column remain distinct: (NULL, 1) != (NULL, 2).
    #[test]
    fn test_row_dedup_null_handling() {
        use arrow_array::StringArray;
        use arrow_schema::{DataType, Field, Schema};

        let schema: SchemaRef = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Utf8, true),
            Field::new("b", DataType::Int64, true),
        ]));
        let mut rd = RowDedupState::try_new(&schema).expect("schema should be supported");

        let batch_nulls = RecordBatch::try_new(
            Arc::clone(&schema),
            vec![
                Arc::new(StringArray::from(vec![None::<&str>, None::<&str>])),
                Arc::new(arrow_array::Int64Array::from(vec![1i64, 1i64])),
            ],
        )
        .unwrap();
        let delta = rd.compute_delta(&[batch_nulls], &schema).unwrap();
        let rows: usize = delta.iter().map(|b| b.num_rows()).sum();
        assert_eq!(rows, 1, "two identical NULL rows should be deduped to one");

        let batch_diff = RecordBatch::try_new(
            Arc::clone(&schema),
            vec![
                Arc::new(StringArray::from(vec![None::<&str>])),
                Arc::new(arrow_array::Int64Array::from(vec![2i64])),
            ],
        )
        .unwrap();
        let delta2 = rd.compute_delta(&[batch_diff], &schema).unwrap();
        let rows2: usize = delta2.iter().map(|b| b.num_rows()).sum();
        assert_eq!(rows2, 1, "(NULL, 2) is distinct from (NULL, 1)");
    }
3410
    // Duplicates within a single candidate batch are collapsed too.
    #[test]
    fn test_row_dedup_within_candidate_dedup() {
        let schema = test_schema();
        let mut rd = RowDedupState::try_new(&schema).expect("schema should be supported");

        let batch = make_batch(&["a", "a", "b"], &[1, 1, 2]);
        let delta = rd.compute_delta(&[batch], &schema).unwrap();
        let rows: usize = delta.iter().map(|b| b.num_rows()).sum();
        assert_eq!(rows, 2, "within-batch dup should be collapsed: a:1, b:2");
    }
3423
    // Float columns are rounded so near-equal values become identical and can
    // dedup (guards against non-terminating fixpoints from float noise).
    #[test]
    fn test_round_float_columns_near_duplicates() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("name", DataType::Utf8, true),
            Field::new("dist", DataType::Float64, true),
        ]));
        let batch = RecordBatch::try_new(
            Arc::clone(&schema),
            vec![
                Arc::new(StringArray::from(vec![Some("a"), Some("a")])),
                Arc::new(Float64Array::from(vec![1.0000000000001, 1.0000000000002])),
            ],
        )
        .unwrap();

        let rounded = round_float_columns(&[batch]);
        assert_eq!(rounded.len(), 1);
        let col = rounded[0]
            .column(1)
            .as_any()
            .downcast_ref::<Float64Array>()
            .unwrap();
        assert_eq!(col.value(0), col.value(1));
    }
3451
    // write_data/get round-trip through the registry's shared RwLock'd buffer.
    #[test]
    fn test_registry_write_read_round_trip() {
        let schema = test_schema();
        let data = Arc::new(RwLock::new(Vec::new()));
        let mut reg = DerivedScanRegistry::new();
        reg.add(DerivedScanEntry {
            scan_index: 0,
            rule_name: "reachable".into(),
            is_self_ref: true,
            data: Arc::clone(&data),
            schema: Arc::clone(&schema),
        });

        let batch = make_batch(&["x"], &[42]);
        reg.write_data(0, vec![batch.clone()]);

        let entry = reg.get(0).unwrap();
        let guard = entry.data.read();
        assert_eq!(guard.len(), 1);
        assert_eq!(guard[0].num_rows(), 1);
    }
3475
    // entries_for_rule filters by rule name across mixed registrations.
    #[test]
    fn test_registry_entries_for_rule() {
        let schema = test_schema();
        let mut reg = DerivedScanRegistry::new();
        reg.add(DerivedScanEntry {
            scan_index: 0,
            rule_name: "r1".into(),
            is_self_ref: true,
            data: Arc::new(RwLock::new(Vec::new())),
            schema: Arc::clone(&schema),
        });
        reg.add(DerivedScanEntry {
            scan_index: 1,
            rule_name: "r2".into(),
            is_self_ref: false,
            data: Arc::new(RwLock::new(Vec::new())),
            schema: Arc::clone(&schema),
        });
        reg.add(DerivedScanEntry {
            scan_index: 2,
            rule_name: "r1".into(),
            is_self_ref: false,
            data: Arc::new(RwLock::new(Vec::new())),
            schema: Arc::clone(&schema),
        });

        assert_eq!(reg.entries_for_rule("r1").len(), 2);
        assert_eq!(reg.entries_for_rule("r2").len(), 1);
        assert_eq!(reg.entries_for_rule("r3").len(), 0);
    }
3506
3507 #[test]
3510 fn test_monotonic_agg_update_and_stability() {
3511 use crate::query::df_graph::locy_fold::FoldAggKind;
3512
3513 let bindings = vec![MonotonicFoldBinding {
3514 fold_name: "total".into(),
3515 kind: FoldAggKind::Sum,
3516 input_col_index: 1,
3517 input_col_name: None,
3518 }];
3519 let mut agg = MonotonicAggState::new(bindings);
3520
3521 let batch = make_batch(&["a"], &[10]);
3523 agg.snapshot();
3524 let changed = agg.update(&[0], &[batch], false).unwrap();
3525 assert!(changed);
3526 assert!(!agg.is_stable()); agg.snapshot();
3530 let changed = agg.update(&[0], &[], false).unwrap();
3531 assert!(!changed);
3532 assert!(agg.is_stable());
3533 }
3534
    // A 1-byte budget must trip the memory-limit error on the first merge.
    #[tokio::test]
    async fn test_memory_limit_exceeded() {
        let schema = test_schema();
        let mut state = FixpointState::new("test".into(), schema, vec![0], 1, None, false);

        let batch = make_batch(&["a", "b", "c"], &[1, 2, 3]);
        let result = state.merge_delta(vec![batch], None).await;
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(err.contains("memory limit"), "Error was: {}", err);
    }
3549
    // A stream seeded directly in Emitting state yields each batch exactly
    // once, then terminates.
    #[tokio::test]
    async fn test_fixpoint_stream_emitting() {
        use futures::StreamExt;

        let schema = test_schema();
        let batch1 = make_batch(&["a"], &[1]);
        let batch2 = make_batch(&["b"], &[2]);

        let metrics = ExecutionPlanMetricsSet::new();
        let baseline = BaselineMetrics::new(&metrics, 0);

        let mut stream = FixpointStream {
            state: FixpointStreamState::Emitting(vec![batch1, batch2], 0),
            schema,
            metrics: baseline,
        };

        let stream = Pin::new(&mut stream);
        let batches: Vec<RecordBatch> = stream.filter_map(|r| async { r.ok() }).collect().await;

        assert_eq!(batches.len(), 2);
        assert_eq!(batches[0].num_rows(), 1);
        assert_eq!(batches[1].num_rows(), 1);
    }
3576
    /// Builds a (name: Utf8, value: Float64) batch from parallel slices.
    fn make_f64_batch(names: &[&str], values: &[f64]) -> RecordBatch {
        let schema = Arc::new(Schema::new(vec![
            Field::new("name", DataType::Utf8, true),
            Field::new("value", DataType::Float64, true),
        ]));
        RecordBatch::try_new(
            schema,
            vec![
                Arc::new(StringArray::from(
                    names.iter().map(|s| Some(*s)).collect::<Vec<_>>(),
                )),
                Arc::new(Float64Array::from(values.to_vec())),
            ],
        )
        .unwrap()
    }
3595
    /// Single noisy-or (Nor) fold binding named "prob" over column 1.
    fn make_nor_binding() -> Vec<MonotonicFoldBinding> {
        use crate::query::df_graph::locy_fold::FoldAggKind;
        vec![MonotonicFoldBinding {
            fold_name: "prob".into(),
            kind: FoldAggKind::Nor,
            input_col_index: 1,
            input_col_name: None,
        }]
    }
3605
    /// Single product (Prod) fold binding named "prob" over column 1.
    fn make_prod_binding() -> Vec<MonotonicFoldBinding> {
        use crate::query::df_graph::locy_fold::FoldAggKind;
        vec![MonotonicFoldBinding {
            fold_name: "prob".into(),
            kind: FoldAggKind::Prod,
            input_col_index: 1,
            input_col_name: None,
        }]
    }
3615
    /// Accumulator lookup key for group `name` under the "prob" fold.
    fn acc_key(name: &str) -> (Vec<ScalarKey>, String) {
        (vec![ScalarKey::Utf8(name.to_string())], "prob".to_string())
    }
3619
    // First NOR update stores the input itself: 1 - (1 - 0.3) = 0.3.
    #[test]
    fn test_monotonic_nor_first_update() {
        let mut agg = MonotonicAggState::new(make_nor_binding());
        let batch = make_f64_batch(&["a"], &[0.3]);
        let changed = agg.update(&[0], &[batch], false).unwrap();
        assert!(changed);
        let val = agg.get_accumulator(&acc_key("a")).unwrap();
        assert!((val - 0.3).abs() < 1e-10, "expected 0.3, got {}", val);
    }
3629
    // Noisy-or combine: 1 - (1 - 0.3)(1 - 0.5) = 0.65.
    #[test]
    fn test_monotonic_nor_two_updates() {
        let mut agg = MonotonicAggState::new(make_nor_binding());
        let batch1 = make_f64_batch(&["a"], &[0.3]);
        agg.update(&[0], &[batch1], false).unwrap();
        let batch2 = make_f64_batch(&["a"], &[0.5]);
        agg.update(&[0], &[batch2], false).unwrap();
        let val = agg.get_accumulator(&acc_key("a")).unwrap();
        assert!((val - 0.65).abs() < 1e-10, "expected 0.65, got {}", val);
    }
3641
    // First PROD update stores the input itself.
    #[test]
    fn test_monotonic_prod_first_update() {
        let mut agg = MonotonicAggState::new(make_prod_binding());
        let batch = make_f64_batch(&["a"], &[0.6]);
        let changed = agg.update(&[0], &[batch], false).unwrap();
        assert!(changed);
        let val = agg.get_accumulator(&acc_key("a")).unwrap();
        assert!((val - 0.6).abs() < 1e-10, "expected 0.6, got {}", val);
    }
3651
    // Product combine: 0.6 * 0.8 = 0.48.
    #[test]
    fn test_monotonic_prod_two_updates() {
        let mut agg = MonotonicAggState::new(make_prod_binding());
        let batch1 = make_f64_batch(&["a"], &[0.6]);
        agg.update(&[0], &[batch1], false).unwrap();
        let batch2 = make_f64_batch(&["a"], &[0.8]);
        agg.update(&[0], &[batch2], false).unwrap();
        let val = agg.get_accumulator(&acc_key("a")).unwrap();
        assert!((val - 0.48).abs() < 1e-10, "expected 0.48, got {}", val);
    }
3663
    // An empty delta after a snapshot leaves the NOR accumulator stable.
    #[test]
    fn test_monotonic_nor_stability() {
        let mut agg = MonotonicAggState::new(make_nor_binding());
        let batch = make_f64_batch(&["a"], &[0.3]);
        agg.update(&[0], &[batch], false).unwrap();
        agg.snapshot();
        let changed = agg.update(&[0], &[], false).unwrap();
        assert!(!changed);
        assert!(agg.is_stable());
    }
3674
    // An empty delta after a snapshot leaves the PROD accumulator stable.
    #[test]
    fn test_monotonic_prod_stability() {
        let mut agg = MonotonicAggState::new(make_prod_binding());
        let batch = make_f64_batch(&["a"], &[0.6]);
        agg.update(&[0], &[batch], false).unwrap();
        agg.snapshot();
        let changed = agg.update(&[0], &[], false).unwrap();
        assert!(!changed);
        assert!(agg.is_stable());
    }
3685
    // Groups accumulate independently:
    // a: 1-(1-0.3)(1-0.5)=0.65, b: 1-(1-0.5)(1-0.2)=0.6.
    #[test]
    fn test_monotonic_nor_multi_group() {
        let mut agg = MonotonicAggState::new(make_nor_binding());
        let batch1 = make_f64_batch(&["a", "b"], &[0.3, 0.5]);
        agg.update(&[0], &[batch1], false).unwrap();
        let batch2 = make_f64_batch(&["a", "b"], &[0.5, 0.2]);
        agg.update(&[0], &[batch2], false).unwrap();

        let val_a = agg.get_accumulator(&acc_key("a")).unwrap();
        let val_b = agg.get_accumulator(&acc_key("b")).unwrap();
        assert!(
            (val_a - 0.65).abs() < 1e-10,
            "expected a=0.65, got {}",
            val_a
        );
        assert!((val_b - 0.6).abs() < 1e-10, "expected b=0.6, got {}", val_b);
    }
3704
    // Zero is absorbing for PROD: once the accumulator hits 0, later
    // non-zero updates cannot change it and the state stays stable.
    #[test]
    fn test_monotonic_prod_zero_absorbing() {
        let mut agg = MonotonicAggState::new(make_prod_binding());
        let batch1 = make_f64_batch(&["a"], &[0.5]);
        agg.update(&[0], &[batch1], false).unwrap();
        let batch2 = make_f64_batch(&["a"], &[0.0]);
        agg.update(&[0], &[batch2], false).unwrap();

        let val = agg.get_accumulator(&acc_key("a")).unwrap();
        assert!((val - 0.0).abs() < 1e-10, "expected 0.0, got {}", val);

        agg.snapshot();
        let batch3 = make_f64_batch(&["a"], &[0.5]);
        let changed = agg.update(&[0], &[batch3], false).unwrap();
        assert!(!changed);
        assert!(agg.is_stable());
    }
3724
    // Out-of-domain input (1.5) is clamped to 1.0 in non-strict mode.
    #[test]
    fn test_monotonic_nor_clamping() {
        let mut agg = MonotonicAggState::new(make_nor_binding());
        let batch = make_f64_batch(&["a"], &[1.5]);
        agg.update(&[0], &[batch], false).unwrap();
        let val = agg.get_accumulator(&acc_key("a")).unwrap();
        assert!((val - 1.0).abs() < 1e-10, "expected 1.0, got {}", val);
    }
3734
    // Probability 1.0 is absorbing for noisy-or: 1-(1-0.3)(1-1.0) = 1.0.
    #[test]
    fn test_monotonic_nor_absorbing() {
        let mut agg = MonotonicAggState::new(make_nor_binding());
        let batch1 = make_f64_batch(&["a"], &[0.3]);
        agg.update(&[0], &[batch1], false).unwrap();
        let batch2 = make_f64_batch(&["a"], &[1.0]);
        agg.update(&[0], &[batch2], false).unwrap();
        let val = agg.get_accumulator(&acc_key("a")).unwrap();
        assert!((val - 1.0).abs() < 1e-10, "expected 1.0, got {}", val);
    }
3746
    // Strict mode rejects out-of-domain NOR input instead of clamping,
    // and the error message names the offending setting.
    #[test]
    fn test_monotonic_agg_strict_nor_rejects() {
        let mut agg = MonotonicAggState::new(make_nor_binding());
        let batch = make_f64_batch(&["a"], &[1.5]);
        let result = agg.update(&[0], &[batch], true);
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(
            err.contains("strict_probability_domain"),
            "Expected strict error, got: {}",
            err
        );
    }
3762
    // Strict mode rejects out-of-domain PROD input as well.
    #[test]
    fn test_monotonic_agg_strict_prod_rejects() {
        let mut agg = MonotonicAggState::new(make_prod_binding());
        let batch = make_f64_batch(&["a"], &[2.0]);
        let result = agg.update(&[0], &[batch], true);
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(
            err.contains("strict_probability_domain"),
            "Expected strict error, got: {}",
            err
        );
    }
3776
    // Strict mode accepts in-domain input unchanged.
    #[test]
    fn test_monotonic_agg_strict_accepts_valid() {
        let mut agg = MonotonicAggState::new(make_nor_binding());
        let batch = make_f64_batch(&["a"], &[0.5]);
        let result = agg.update(&[0], &[batch], true);
        assert!(result.is_ok());
        let val = agg.get_accumulator(&acc_key("a")).unwrap();
        assert!((val - 0.5).abs() < 1e-10, "expected 0.5, got {}", val);
    }
3786
    /// Builds a (vid: UInt64, prob: Float64) batch from parallel slices.
    fn make_vid_prob_batch(vids: &[u64], probs: &[f64]) -> RecordBatch {
        use arrow_array::UInt64Array;
        let schema = Arc::new(Schema::new(vec![
            Field::new("vid", DataType::UInt64, true),
            Field::new("prob", DataType::Float64, true),
        ]));
        RecordBatch::try_new(
            schema,
            vec![
                Arc::new(UInt64Array::from(vids.to_vec())),
                Arc::new(Float64Array::from(probs.to_vec())),
            ],
        )
        .unwrap()
    }
3804
3805 #[test]
3806 fn test_prob_complement_basic() {
3807 let body = make_vid_prob_batch(&[1, 2], &[0.9, 0.8]);
3809 let neg = make_vid_prob_batch(&[1], &[0.7]);
3810 let join_cols = vec![("vid".to_string(), "vid".to_string())];
3811 let result = apply_prob_complement_composite(
3812 vec![body],
3813 &[neg],
3814 &join_cols,
3815 "prob",
3816 "__complement_0",
3817 )
3818 .unwrap();
3819 assert_eq!(result.len(), 1);
3820 let batch = &result[0];
3821 let complement = batch
3822 .column_by_name("__complement_0")
3823 .unwrap()
3824 .as_any()
3825 .downcast_ref::<Float64Array>()
3826 .unwrap();
3827 assert!(
3829 (complement.value(0) - 0.3).abs() < 1e-10,
3830 "expected 0.3, got {}",
3831 complement.value(0)
3832 );
3833 assert!(
3835 (complement.value(1) - 1.0).abs() < 1e-10,
3836 "expected 1.0, got {}",
3837 complement.value(1)
3838 );
3839 }
3840
3841 #[test]
3842 fn test_prob_complement_noisy_or_duplicates() {
3843 let body = make_vid_prob_batch(&[1], &[0.9]);
3847 let neg = make_vid_prob_batch(&[1, 1], &[0.3, 0.5]);
3848 let join_cols = vec![("vid".to_string(), "vid".to_string())];
3849 let result = apply_prob_complement_composite(
3850 vec![body],
3851 &[neg],
3852 &join_cols,
3853 "prob",
3854 "__complement_0",
3855 )
3856 .unwrap();
3857 let batch = &result[0];
3858 let complement = batch
3859 .column_by_name("__complement_0")
3860 .unwrap()
3861 .as_any()
3862 .downcast_ref::<Float64Array>()
3863 .unwrap();
3864 assert!(
3865 (complement.value(0) - 0.35).abs() < 1e-10,
3866 "expected 0.35, got {}",
3867 complement.value(0)
3868 );
3869 }
3870
3871 #[test]
3872 fn test_prob_complement_empty_neg() {
3873 let body = make_vid_prob_batch(&[1, 2], &[0.5, 0.6]);
3875 let join_cols = vec![("vid".to_string(), "vid".to_string())];
3876 let result =
3877 apply_prob_complement_composite(vec![body], &[], &join_cols, "prob", "__complement_0")
3878 .unwrap();
3879 let batch = &result[0];
3880 let complement = batch
3881 .column_by_name("__complement_0")
3882 .unwrap()
3883 .as_any()
3884 .downcast_ref::<Float64Array>()
3885 .unwrap();
3886 for i in 0..2 {
3887 assert!(
3888 (complement.value(i) - 1.0).abs() < 1e-10,
3889 "row {}: expected 1.0, got {}",
3890 i,
3891 complement.value(i)
3892 );
3893 }
3894 }
3895
3896 #[test]
3897 fn test_anti_join_basic() {
3898 use arrow_array::UInt64Array;
3900 let body = make_vid_prob_batch(&[1, 2, 3], &[0.5, 0.6, 0.7]);
3901 let neg = make_vid_prob_batch(&[2], &[0.0]);
3902 let join_cols = vec![("vid".to_string(), "vid".to_string())];
3903 let result = apply_anti_join_composite(vec![body], &[neg], &join_cols).unwrap();
3904 assert_eq!(result.len(), 1);
3905 let batch = &result[0];
3906 assert_eq!(batch.num_rows(), 2);
3907 let vids = batch
3908 .column_by_name("vid")
3909 .unwrap()
3910 .as_any()
3911 .downcast_ref::<UInt64Array>()
3912 .unwrap();
3913 assert_eq!(vids.value(0), 1);
3914 assert_eq!(vids.value(1), 3);
3915 }
3916
3917 #[test]
3918 fn test_anti_join_empty_neg() {
3919 let body = make_vid_prob_batch(&[1, 2, 3], &[0.5, 0.6, 0.7]);
3921 let join_cols = vec![("vid".to_string(), "vid".to_string())];
3922 let result = apply_anti_join_composite(vec![body], &[], &join_cols).unwrap();
3923 assert_eq!(result.len(), 1);
3924 assert_eq!(result[0].num_rows(), 3);
3925 }
3926
3927 #[test]
3928 fn test_anti_join_all_excluded() {
3929 let body = make_vid_prob_batch(&[1, 2], &[0.5, 0.6]);
3931 let neg = make_vid_prob_batch(&[1, 2], &[0.0, 0.0]);
3932 let join_cols = vec![("vid".to_string(), "vid".to_string())];
3933 let result = apply_anti_join_composite(vec![body], &[neg], &join_cols).unwrap();
3934 let total: usize = result.iter().map(|b| b.num_rows()).sum();
3935 assert_eq!(total, 0);
3936 }
3937
3938 #[test]
3939 fn test_multiply_prob_single_complement() {
3940 let body = make_vid_prob_batch(&[1], &[0.8]);
3942 let complement_arr = Float64Array::from(vec![0.5]);
3944 let mut cols: Vec<arrow_array::ArrayRef> = body.columns().to_vec();
3945 cols.push(Arc::new(complement_arr));
3946 let mut fields: Vec<Arc<Field>> = body.schema().fields().iter().cloned().collect();
3947 fields.push(Arc::new(Field::new(
3948 "__complement_0",
3949 DataType::Float64,
3950 true,
3951 )));
3952 let schema = Arc::new(Schema::new(fields));
3953 let batch = RecordBatch::try_new(schema, cols).unwrap();
3954
3955 let result =
3956 multiply_prob_factors(vec![batch], Some("prob"), &["__complement_0".to_string()])
3957 .unwrap();
3958 assert_eq!(result.len(), 1);
3959 let out = &result[0];
3960 assert!(out.column_by_name("__complement_0").is_none());
3962 let prob = out
3963 .column_by_name("prob")
3964 .unwrap()
3965 .as_any()
3966 .downcast_ref::<Float64Array>()
3967 .unwrap();
3968 assert!(
3969 (prob.value(0) - 0.4).abs() < 1e-10,
3970 "expected 0.4, got {}",
3971 prob.value(0)
3972 );
3973 }
3974
3975 #[test]
3976 fn test_multiply_prob_multiple_complements() {
3977 let body = make_vid_prob_batch(&[1], &[0.8]);
3979 let c1 = Float64Array::from(vec![0.5]);
3980 let c2 = Float64Array::from(vec![0.6]);
3981 let mut cols: Vec<arrow_array::ArrayRef> = body.columns().to_vec();
3982 cols.push(Arc::new(c1));
3983 cols.push(Arc::new(c2));
3984 let mut fields: Vec<Arc<Field>> = body.schema().fields().iter().cloned().collect();
3985 fields.push(Arc::new(Field::new("__c1", DataType::Float64, true)));
3986 fields.push(Arc::new(Field::new("__c2", DataType::Float64, true)));
3987 let schema = Arc::new(Schema::new(fields));
3988 let batch = RecordBatch::try_new(schema, cols).unwrap();
3989
3990 let result = multiply_prob_factors(
3991 vec![batch],
3992 Some("prob"),
3993 &["__c1".to_string(), "__c2".to_string()],
3994 )
3995 .unwrap();
3996 let out = &result[0];
3997 assert!(out.column_by_name("__c1").is_none());
3998 assert!(out.column_by_name("__c2").is_none());
3999 let prob = out
4000 .column_by_name("prob")
4001 .unwrap()
4002 .as_any()
4003 .downcast_ref::<Float64Array>()
4004 .unwrap();
4005 assert!(
4006 (prob.value(0) - 0.24).abs() < 1e-10,
4007 "expected 0.24, got {}",
4008 prob.value(0)
4009 );
4010 }
4011
4012 #[test]
4013 fn test_multiply_prob_no_prob_column() {
4014 use arrow_array::UInt64Array;
4016 let schema = Arc::new(Schema::new(vec![
4017 Field::new("vid", DataType::UInt64, true),
4018 Field::new("__c1", DataType::Float64, true),
4019 ]));
4020 let batch = RecordBatch::try_new(
4021 schema,
4022 vec![
4023 Arc::new(UInt64Array::from(vec![1u64])),
4024 Arc::new(Float64Array::from(vec![0.7])),
4025 ],
4026 )
4027 .unwrap();
4028
4029 let result = multiply_prob_factors(vec![batch], None, &["__c1".to_string()]).unwrap();
4030 let out = &result[0];
4031 assert!(out.column_by_name("__c1").is_none());
4033 assert_eq!(out.num_columns(), 1);
4035 }
4036}