1use crate::query::df_graph::GraphExecutionContext;
10use crate::query::df_graph::common::{
11 ScalarKey, arrow_err, collect_all_partitions, compute_plan_properties, execute_subplan,
12 extract_scalar_key,
13};
14use crate::query::df_graph::locy_best_by::{BestByExec, SortCriterion};
15use crate::query::df_graph::locy_errors::LocyRuntimeError;
16use crate::query::df_graph::locy_explain::{
17 ProofTerm, ProvenanceAnnotation, ProvenanceStore, compute_proof_probability,
18};
19use crate::query::df_graph::locy_fold::{FoldBinding, FoldExec};
20use crate::query::df_graph::locy_priority::PriorityExec;
21use crate::query::planner::LogicalPlan;
22use arrow_array::RecordBatch;
23use arrow_row::{RowConverter, SortField};
24use arrow_schema::SchemaRef;
25use datafusion::common::JoinType;
26use datafusion::common::Result as DFResult;
27use datafusion::execution::{RecordBatchStream, SendableRecordBatchStream, TaskContext};
28use datafusion::physical_plan::joins::{HashJoinExec, PartitionMode};
29use datafusion::physical_plan::memory::MemoryStream;
30use datafusion::physical_plan::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet};
31use datafusion::physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties};
32use futures::Stream;
33use parking_lot::RwLock;
34use std::any::Any;
35use std::collections::{HashMap, HashSet};
36use std::fmt;
37use std::pin::Pin;
38use std::sync::{Arc, RwLock as StdRwLock};
39use std::task::{Context, Poll};
40use std::time::{Duration, Instant};
41use uni_common::Value;
42use uni_common::core::schema::Schema as UniSchema;
43use uni_locy::RuntimeWarning;
44use uni_store::storage::manager::StorageManager;
45
/// Shared handle to the facts derived so far for one derived-table scan
/// participating in the fixpoint loop. The fixpoint driver overwrites
/// `data` between iterations via `DerivedScanRegistry::write_data`.
#[derive(Debug)]
pub struct DerivedScanEntry {
    /// Lookup key for this scan inside [`DerivedScanRegistry`].
    pub scan_index: usize,
    /// Name of the rule whose derived facts this scan reads.
    pub rule_name: String,
    /// NOTE(review): set by the planner; semantics (scan references the
    /// rule it occurs in) inferred from the name — confirm at the
    /// construction site, it is not read in this file's visible code.
    pub is_self_ref: bool,
    /// Batches currently visible to the scan; shared with executing plans.
    pub data: Arc<RwLock<Vec<RecordBatch>>>,
    /// Arrow schema of the batches stored in `data`.
    pub schema: SchemaRef,
}
68
/// Registry of every [`DerivedScanEntry`] taking part in a fixpoint
/// execution. Lookups are linear scans over the insertion-ordered vec,
/// which is fine for the small number of scans per query.
#[derive(Debug, Default)]
pub struct DerivedScanRegistry {
    // Entries in insertion order; `scan_index` is the logical key.
    entries: Vec<DerivedScanEntry>,
}
79
80impl DerivedScanRegistry {
81 pub fn new() -> Self {
83 Self::default()
84 }
85
86 pub fn add(&mut self, entry: DerivedScanEntry) {
88 self.entries.push(entry);
89 }
90
91 pub fn get(&self, scan_index: usize) -> Option<&DerivedScanEntry> {
93 self.entries.iter().find(|e| e.scan_index == scan_index)
94 }
95
96 pub fn write_data(&self, scan_index: usize, batches: Vec<RecordBatch>) {
98 if let Some(entry) = self.get(scan_index) {
99 let mut guard = entry.data.write();
100 *guard = batches;
101 }
102 }
103
104 pub fn entries_for_rule(&self, rule_name: &str) -> Vec<&DerivedScanEntry> {
106 self.entries
107 .iter()
108 .filter(|e| e.rule_name == rule_name)
109 .collect()
110 }
111}
112
/// One FOLD aggregate tracked monotonically across fixpoint iterations.
#[derive(Debug, Clone)]
pub struct MonotonicFoldBinding {
    /// Output name of the fold; part of the accumulator map key.
    pub fold_name: String,
    /// Aggregate kind (Sum/Count/Max/Min/Nor/Prod/...).
    pub kind: crate::query::df_graph::locy_fold::FoldAggKind,
    /// Column index in the delta batches supplying the input values.
    pub input_col_index: usize,
}
124
/// Per-rule accumulator state for monotonic FOLD aggregates. Convergence
/// is detected by comparing the current accumulators against the snapshot
/// taken on the previous iteration (`is_stable`).
#[derive(Debug)]
pub struct MonotonicAggState {
    // Current value per (group key, fold name).
    accumulators: HashMap<(Vec<ScalarKey>, String), f64>,
    // Copy of `accumulators` taken at the last `snapshot()` call.
    prev_snapshot: HashMap<(Vec<ScalarKey>, String), f64>,
    // The folds being tracked; applied to every delta row in `update`.
    bindings: Vec<MonotonicFoldBinding>,
}
139
140impl MonotonicAggState {
141 pub fn new(bindings: Vec<MonotonicFoldBinding>) -> Self {
143 Self {
144 accumulators: HashMap::new(),
145 prev_snapshot: HashMap::new(),
146 bindings,
147 }
148 }
149
150 pub fn update(
156 &mut self,
157 key_indices: &[usize],
158 delta_batches: &[RecordBatch],
159 strict: bool,
160 ) -> DFResult<bool> {
161 use crate::query::df_graph::locy_fold::FoldAggKind;
162
163 let mut changed = false;
164 for batch in delta_batches {
165 for row_idx in 0..batch.num_rows() {
166 let group_key = extract_scalar_key(batch, key_indices, row_idx);
167 for binding in &self.bindings {
168 let col = batch.column(binding.input_col_index);
169 let val = extract_f64(col.as_ref(), row_idx);
170 if let Some(val) = val {
171 let map_key = (group_key.clone(), binding.fold_name.clone());
172 let entry = self
173 .accumulators
174 .entry(map_key)
175 .or_insert(binding.kind.identity().unwrap_or(0.0));
176 let old = *entry;
177 match binding.kind {
178 FoldAggKind::Sum | FoldAggKind::Count => *entry += val,
179 FoldAggKind::Max => {
180 if val > *entry {
181 *entry = val;
182 }
183 }
184 FoldAggKind::Min => {
185 if val < *entry {
186 *entry = val;
187 }
188 }
189 FoldAggKind::Nor => {
190 if strict && !(0.0..=1.0).contains(&val) {
191 return Err(datafusion::error::DataFusionError::Execution(
192 format!(
193 "strict_probability_domain: MNOR input {val} is outside [0, 1]"
194 ),
195 ));
196 }
197 if !strict && !(0.0..=1.0).contains(&val) {
198 tracing::warn!(
199 "MNOR input {val} outside [0,1], clamped to {}",
200 val.clamp(0.0, 1.0)
201 );
202 }
203 let p = val.clamp(0.0, 1.0);
204 *entry = 1.0 - (1.0 - *entry) * (1.0 - p);
205 }
206 FoldAggKind::Prod => {
207 if strict && !(0.0..=1.0).contains(&val) {
208 return Err(datafusion::error::DataFusionError::Execution(
209 format!(
210 "strict_probability_domain: MPROD input {val} is outside [0, 1]"
211 ),
212 ));
213 }
214 if !strict && !(0.0..=1.0).contains(&val) {
215 tracing::warn!(
216 "MPROD input {val} outside [0,1], clamped to {}",
217 val.clamp(0.0, 1.0)
218 );
219 }
220 let p = val.clamp(0.0, 1.0);
221 *entry *= p;
222 }
223 _ => {}
224 }
225 if (*entry - old).abs() > f64::EPSILON {
226 changed = true;
227 }
228 }
229 }
230 }
231 }
232 Ok(changed)
233 }
234
235 pub fn snapshot(&mut self) {
237 self.prev_snapshot = self.accumulators.clone();
238 }
239
240 pub fn is_stable(&self) -> bool {
242 if self.accumulators.len() != self.prev_snapshot.len() {
243 return false;
244 }
245 for (key, val) in &self.accumulators {
246 match self.prev_snapshot.get(key) {
247 Some(prev) if (*val - *prev).abs() <= f64::EPSILON => {}
248 _ => return false,
249 }
250 }
251 true
252 }
253
254 #[cfg(test)]
256 pub(crate) fn get_accumulator(&self, key: &(Vec<ScalarKey>, String)) -> Option<f64> {
257 self.accumulators.get(key).copied()
258 }
259}
260
261fn extract_f64(col: &dyn arrow_array::Array, row_idx: usize) -> Option<f64> {
263 if col.is_null(row_idx) {
264 return None;
265 }
266 if let Some(arr) = col.as_any().downcast_ref::<arrow_array::Float64Array>() {
267 Some(arr.value(row_idx))
268 } else {
269 col.as_any()
270 .downcast_ref::<arrow_array::Int64Array>()
271 .map(|arr| arr.value(row_idx) as f64)
272 }
273}
274
/// Row-format based duplicate tracker: each row is converted to its
/// arrow-row byte encoding and remembered in `seen`, so dedup is a
/// byte-comparison instead of per-column scalar extraction.
struct RowDedupState {
    // Converts record-batch columns into comparable row encodings.
    converter: RowConverter,
    // Byte encodings of every row observed so far.
    seen: HashSet<Box<[u8]>>,
}
288
289impl RowDedupState {
290 fn try_new(schema: &SchemaRef) -> Option<Self> {
295 let fields: Vec<SortField> = schema
296 .fields()
297 .iter()
298 .map(|f| SortField::new(f.data_type().clone()))
299 .collect();
300 match RowConverter::new(fields) {
301 Ok(converter) => Some(Self {
302 converter,
303 seen: HashSet::new(),
304 }),
305 Err(e) => {
306 tracing::warn!(
307 "RowDedupState: RowConverter unsupported for schema, falling back to legacy dedup: {}",
308 e
309 );
310 None
311 }
312 }
313 }
314
315 fn ingest_existing(&mut self, facts: &[RecordBatch], _schema: &SchemaRef) {
320 self.seen.clear();
321 for batch in facts {
322 if batch.num_rows() == 0 {
323 continue;
324 }
325 let arrays: Vec<_> = batch.columns().to_vec();
326 if let Ok(rows) = self.converter.convert_columns(&arrays) {
327 for row_idx in 0..batch.num_rows() {
328 let row_bytes: Box<[u8]> = rows.row(row_idx).data().into();
329 self.seen.insert(row_bytes);
330 }
331 }
332 }
333 }
334
335 fn compute_delta(
341 &mut self,
342 candidates: &[RecordBatch],
343 schema: &SchemaRef,
344 ) -> DFResult<Vec<RecordBatch>> {
345 let mut delta_batches = Vec::new();
346 for batch in candidates {
347 if batch.num_rows() == 0 {
348 continue;
349 }
350
351 let arrays: Vec<_> = batch.columns().to_vec();
353 let rows = self.converter.convert_columns(&arrays).map_err(arrow_err)?;
354
355 let mut keep = Vec::with_capacity(batch.num_rows());
357 for row_idx in 0..batch.num_rows() {
358 let row_bytes: Box<[u8]> = rows.row(row_idx).data().into();
359 keep.push(self.seen.insert(row_bytes));
360 }
361
362 let keep_mask = arrow_array::BooleanArray::from(keep);
363 let new_cols = batch
364 .columns()
365 .iter()
366 .map(|col| {
367 arrow::compute::filter(col.as_ref(), &keep_mask).map_err(|e| {
368 datafusion::error::DataFusionError::ArrowError(Box::new(e), None)
369 })
370 })
371 .collect::<DFResult<Vec<_>>>()?;
372
373 if new_cols.first().is_some_and(|c| !c.is_empty()) {
374 let filtered = RecordBatch::try_new(Arc::clone(schema), new_cols).map_err(|e| {
375 datafusion::error::DataFusionError::ArrowError(Box::new(e), None)
376 })?;
377 delta_batches.push(filtered);
378 }
379 }
380 Ok(delta_batches)
381 }
382}
383
/// Accumulated derivation state for one fixpoint rule: the full fact set,
/// the most recent delta, and the bookkeeping needed to dedup, bound
/// memory, and detect convergence.
pub struct FixpointState {
    // Rule name, used in log/error messages.
    rule_name: String,
    // All facts derived so far.
    facts: Vec<RecordBatch>,
    // Facts newly derived in the most recent merge; empty delta feeds
    // into convergence detection.
    delta: Vec<RecordBatch>,
    // Current output schema (may be reconciled from plan output).
    schema: SchemaRef,
    // Column indices forming the grouping key (folds / BEST BY).
    key_column_indices: Vec<usize>,
    // 0..num_cols, precomputed for whole-row scalar-key extraction.
    all_column_indices: Vec<usize>,
    // Running byte size of `facts`, tracked against the limit below.
    facts_bytes: usize,
    // Memory budget; exceeding it aborts with MemoryLimitExceeded.
    max_derived_bytes: usize,
    // Present when the rule has monotonic FOLD aggregates.
    monotonic_agg: Option<MonotonicAggState>,
    // Row-format dedup; None when the schema is unsupported, in which
    // case the legacy scalar-key dedup is used.
    row_dedup: Option<RowDedupState>,
    // Error (vs. clamp) on probability inputs outside [0, 1].
    strict_probability_domain: bool,
}
412
impl FixpointState {
    /// Creates an empty state for `rule_name` with the given schema,
    /// grouping key, memory budget, and optional monotonic-fold tracker.
    pub fn new(
        rule_name: String,
        schema: SchemaRef,
        key_column_indices: Vec<usize>,
        max_derived_bytes: usize,
        monotonic_agg: Option<MonotonicAggState>,
        strict_probability_domain: bool,
    ) -> Self {
        let num_cols = schema.fields().len();
        let row_dedup = RowDedupState::try_new(&schema);
        Self {
            rule_name,
            facts: Vec::new(),
            delta: Vec::new(),
            schema,
            key_column_indices,
            all_column_indices: (0..num_cols).collect(),
            facts_bytes: 0,
            max_derived_bytes,
            monotonic_agg,
            row_dedup,
            strict_probability_domain,
        }
    }

    /// Adopts the physical plan's actual output schema when it differs
    /// from the planned one, rebuilding the row-dedup state to match.
    /// Note this resets `row_dedup.seen`; existing facts are not
    /// re-ingested here (the dedup paths below handle duplicates).
    fn reconcile_schema(&mut self, actual_schema: &SchemaRef) {
        if self.schema.fields() != actual_schema.fields() {
            tracing::debug!(
                rule = %self.rule_name,
                "Reconciling fixpoint schema from physical plan output",
            );
            self.schema = Arc::clone(actual_schema);
            self.row_dedup = RowDedupState::try_new(&self.schema);
        }
    }

    /// Merges candidate batches into the fact set, keeping only rows not
    /// seen before. Returns `Ok(true)` when anything new was added (or a
    /// monotonic accumulator moved). Enforces the memory budget and
    /// updates monotonic-fold accumulators from the delta.
    pub async fn merge_delta(
        &mut self,
        candidates: Vec<RecordBatch>,
        task_ctx: Option<Arc<TaskContext>>,
    ) -> DFResult<bool> {
        // No candidates at all: clear the delta so convergence can fire.
        if candidates.is_empty() || candidates.iter().all(|b| b.num_rows() == 0) {
            self.delta.clear();
            return Ok(false);
        }

        if let Some(first) = candidates.iter().find(|b| b.num_rows() > 0) {
            self.reconcile_schema(&first.schema());
        }

        // Round floats to 12 decimals so numeric noise doesn't defeat dedup.
        let candidates = round_float_columns(&candidates);

        let delta = self.compute_delta(&candidates, task_ctx.as_ref()).await?;

        if delta.is_empty() || delta.iter().all(|b| b.num_rows() == 0) {
            self.delta.clear();
            // Snapshot so is_stable() compares against this iteration.
            if let Some(ref mut agg) = self.monotonic_agg {
                agg.snapshot();
            }
            return Ok(false);
        }

        let delta_bytes: usize = delta.iter().map(batch_byte_size).sum();
        if self.facts_bytes + delta_bytes > self.max_derived_bytes {
            return Err(datafusion::error::DataFusionError::Execution(
                LocyRuntimeError::MemoryLimitExceeded {
                    rule: self.rule_name.clone(),
                    bytes: self.facts_bytes + delta_bytes,
                    limit: self.max_derived_bytes,
                }
                .to_string(),
            ));
        }

        // Snapshot-then-update: is_stable() later compares post-update
        // accumulators against this pre-update snapshot.
        if let Some(ref mut agg) = self.monotonic_agg {
            agg.snapshot();
            agg.update(
                &self.key_column_indices,
                &delta,
                self.strict_probability_domain,
            )?;
        }

        self.facts_bytes += delta_bytes;
        self.facts.extend(delta.iter().cloned());
        self.delta = delta;

        Ok(true)
    }

    /// Chooses a dedup strategy: a DataFusion LEFT ANTI join once the
    /// fact set is large (and a task context is available), otherwise the
    /// in-process row-format dedup, otherwise the legacy scalar-key path.
    ///
    /// NOTE(review): the anti-join path does not update `row_dedup.seen`;
    /// since facts only grow, later iterations keep taking the anti-join
    /// branch, but confirm the join also removes duplicates *within*
    /// `candidates` — a plain LEFT ANTI only filters against `facts`.
    async fn compute_delta(
        &mut self,
        candidates: &[RecordBatch],
        task_ctx: Option<&Arc<TaskContext>>,
    ) -> DFResult<Vec<RecordBatch>> {
        let total_existing: usize = self.facts.iter().map(|b| b.num_rows()).sum();
        if total_existing >= DEDUP_ANTI_JOIN_THRESHOLD
            && let Some(ctx) = task_ctx
        {
            return arrow_left_anti_dedup(candidates.to_vec(), &self.facts, &self.schema, ctx)
                .await;
        }
        if let Some(ref mut rd) = self.row_dedup {
            rd.compute_delta(candidates, &self.schema)
        } else {
            self.compute_delta_legacy(candidates)
        }
    }

    /// Scalar-key fallback dedup: builds a set of whole-row keys from
    /// `facts`, then keeps only candidate rows whose key is new — both
    /// relative to `facts` and within the candidates themselves.
    fn compute_delta_legacy(&self, candidates: &[RecordBatch]) -> DFResult<Vec<RecordBatch>> {
        let mut existing: HashSet<Vec<ScalarKey>> = HashSet::new();
        for batch in &self.facts {
            for row_idx in 0..batch.num_rows() {
                let key = extract_scalar_key(batch, &self.all_column_indices, row_idx);
                existing.insert(key);
            }
        }

        let mut delta_batches = Vec::new();
        for batch in candidates {
            if batch.num_rows() == 0 {
                continue;
            }
            // Pass 1: mark rows not already in `facts`.
            let mut keep = Vec::with_capacity(batch.num_rows());
            for row_idx in 0..batch.num_rows() {
                let key = extract_scalar_key(batch, &self.all_column_indices, row_idx);
                keep.push(!existing.contains(&key));
            }

            // Pass 2: insert kept keys; a failed insert means an earlier
            // candidate row in this merge had the same key — drop it.
            for (row_idx, kept) in keep.iter_mut().enumerate() {
                if *kept {
                    let key = extract_scalar_key(batch, &self.all_column_indices, row_idx);
                    if !existing.insert(key) {
                        *kept = false;
                    }
                }
            }

            let keep_mask = arrow_array::BooleanArray::from(keep);
            let new_rows = batch
                .columns()
                .iter()
                .map(|col| {
                    arrow::compute::filter(col.as_ref(), &keep_mask).map_err(|e| {
                        datafusion::error::DataFusionError::ArrowError(Box::new(e), None)
                    })
                })
                .collect::<DFResult<Vec<_>>>()?;

            if new_rows.first().is_some_and(|c| !c.is_empty()) {
                let filtered =
                    RecordBatch::try_new(Arc::clone(&self.schema), new_rows).map_err(|e| {
                        datafusion::error::DataFusionError::ArrowError(Box::new(e), None)
                    })?;
                delta_batches.push(filtered);
            }
        }

        Ok(delta_batches)
    }

    /// Converged when the last merge produced no new rows AND any
    /// monotonic accumulators are unchanged since their last snapshot.
    pub fn is_converged(&self) -> bool {
        let delta_empty = self.delta.is_empty() || self.delta.iter().all(|b| b.num_rows() == 0);
        let agg_stable = self.monotonic_agg.as_ref().is_none_or(|a| a.is_stable());
        delta_empty && agg_stable
    }

    /// All facts derived so far.
    pub fn all_facts(&self) -> &[RecordBatch] {
        &self.facts
    }

    /// Facts added by the most recent merge.
    pub fn all_delta(&self) -> &[RecordBatch] {
        &self.delta
    }

    /// Consumes the state, yielding the full fact set.
    pub fn into_facts(self) -> Vec<RecordBatch> {
        self.facts
    }

    /// BEST BY merge: combines existing facts with `candidates`, sorts by
    /// (group key ASC, criteria), and keeps only the first row of each
    /// group. Returns `Ok(true)` when the per-group winning criteria
    /// changed compared to before the merge.
    pub fn merge_best_by(
        &mut self,
        candidates: Vec<RecordBatch>,
        sort_criteria: &[SortCriterion],
    ) -> DFResult<bool> {
        if candidates.is_empty() || candidates.iter().all(|b| b.num_rows() == 0) {
            self.delta.clear();
            return Ok(false);
        }

        if let Some(first) = candidates.iter().find(|b| b.num_rows() > 0) {
            self.reconcile_schema(&first.schema());
        }

        let candidates = round_float_columns(&candidates);

        // Winning criteria per group before this merge, for change detection.
        let old_best: HashMap<Vec<ScalarKey>, Vec<ScalarKey>> =
            self.build_key_criteria_map(sort_criteria);

        let mut all_batches = self.facts.clone();
        all_batches.extend(candidates);
        let all_batches: Vec<_> = all_batches
            .into_iter()
            .filter(|b| b.num_rows() > 0)
            .collect();
        if all_batches.is_empty() {
            self.delta.clear();
            return Ok(false);
        }

        let combined = arrow::compute::concat_batches(&self.schema, &all_batches)
            .map_err(|e| datafusion::error::DataFusionError::ArrowError(Box::new(e), None))?;

        if combined.num_rows() == 0 {
            self.delta.clear();
            return Ok(false);
        }

        // Sort spec: group-key columns first (ascending, nulls last),
        // then the user criteria — so the best row per group sorts first.
        let mut sort_columns = Vec::new();
        for &ki in &self.key_column_indices {
            sort_columns.push(arrow::compute::SortColumn {
                values: Arc::clone(combined.column(ki)),
                options: Some(arrow::compute::SortOptions {
                    descending: false,
                    nulls_first: false,
                }),
            });
        }
        for criterion in sort_criteria {
            sort_columns.push(arrow::compute::SortColumn {
                values: Arc::clone(combined.column(criterion.col_index)),
                options: Some(arrow::compute::SortOptions {
                    descending: !criterion.ascending,
                    nulls_first: criterion.nulls_first,
                }),
            });
        }

        let sorted_indices =
            arrow::compute::lexsort_to_indices(&sort_columns, None).map_err(arrow_err)?;
        let sorted_columns: Vec<_> = combined
            .columns()
            .iter()
            .map(|col| arrow::compute::take(col.as_ref(), &sorted_indices, None))
            .collect::<Result<Vec<_>, _>>()
            .map_err(|e| datafusion::error::DataFusionError::ArrowError(Box::new(e), None))?;
        let sorted = RecordBatch::try_new(Arc::clone(&self.schema), sorted_columns)
            .map_err(|e| datafusion::error::DataFusionError::ArrowError(Box::new(e), None))?;

        // Keep only the first (i.e. best) row of each group key run.
        let mut keep_indices: Vec<u32> = Vec::new();
        let mut prev_key: Option<Vec<ScalarKey>> = None;
        for row_idx in 0..sorted.num_rows() {
            let key = extract_scalar_key(&sorted, &self.key_column_indices, row_idx);
            let is_new_group = match &prev_key {
                None => true,
                Some(prev) => *prev != key,
            };
            if is_new_group {
                keep_indices.push(row_idx as u32);
                prev_key = Some(key);
            }
        }

        let keep_array = arrow_array::UInt32Array::from(keep_indices);
        let output_columns: Vec<_> = sorted
            .columns()
            .iter()
            .map(|col| arrow::compute::take(col.as_ref(), &keep_array, None))
            .collect::<Result<Vec<_>, _>>()
            .map_err(|e| datafusion::error::DataFusionError::ArrowError(Box::new(e), None))?;
        let pruned = RecordBatch::try_new(Arc::clone(&self.schema), output_columns)
            .map_err(|e| datafusion::error::DataFusionError::ArrowError(Box::new(e), None))?;

        // Winning criteria per group after the merge.
        let new_best: HashMap<Vec<ScalarKey>, Vec<ScalarKey>> = {
            let mut map = HashMap::new();
            for row_idx in 0..pruned.num_rows() {
                let key = extract_scalar_key(&pruned, &self.key_column_indices, row_idx);
                let criteria: Vec<ScalarKey> = sort_criteria
                    .iter()
                    .flat_map(|c| extract_scalar_key(&pruned, &[c.col_index], row_idx))
                    .collect();
                map.insert(key, criteria);
            }
            map
        };
        let changed = old_best != new_best;

        tracing::debug!(
            rule = %self.rule_name,
            old_keys = old_best.len(),
            new_keys = new_best.len(),
            changed = changed,
            "BEST BY merge"
        );

        // The pruned batch fully replaces the fact set (facts here are
        // at most one winning row per group).
        self.facts_bytes = batch_byte_size(&pruned);
        self.facts = vec![pruned];
        if changed {
            self.delta = self.facts.clone();
        } else {
            self.delta.clear();
        }

        // Rebuild row dedup over the replaced fact set.
        self.row_dedup = RowDedupState::try_new(&self.schema);
        if let Some(ref mut rd) = self.row_dedup {
            rd.ingest_existing(&self.facts, &self.schema);
        }

        Ok(changed)
    }

    /// Maps each group key in `facts` to the values of the sort-criteria
    /// columns for that row; used by `merge_best_by` change detection.
    fn build_key_criteria_map(
        &self,
        sort_criteria: &[SortCriterion],
    ) -> HashMap<Vec<ScalarKey>, Vec<ScalarKey>> {
        let mut map = HashMap::new();
        for batch in &self.facts {
            for row_idx in 0..batch.num_rows() {
                let key = extract_scalar_key(batch, &self.key_column_indices, row_idx);
                let criteria: Vec<ScalarKey> = sort_criteria
                    .iter()
                    .flat_map(|c| extract_scalar_key(batch, &[c.col_index], row_idx))
                    .collect();
                map.insert(key, criteria);
            }
        }
        map
    }
}
800
801fn batch_byte_size(batch: &RecordBatch) -> usize {
803 batch
804 .columns()
805 .iter()
806 .map(|col| col.get_buffer_memory_size())
807 .sum()
808}
809
810fn round_float_columns(batches: &[RecordBatch]) -> Vec<RecordBatch> {
816 batches
817 .iter()
818 .map(|batch| {
819 let schema = batch.schema();
820 let has_float = schema
821 .fields()
822 .iter()
823 .any(|f| *f.data_type() == arrow_schema::DataType::Float64);
824 if !has_float {
825 return batch.clone();
826 }
827
828 let columns: Vec<arrow_array::ArrayRef> = batch
829 .columns()
830 .iter()
831 .enumerate()
832 .map(|(i, col)| {
833 if *schema.field(i).data_type() == arrow_schema::DataType::Float64 {
834 let arr = col
835 .as_any()
836 .downcast_ref::<arrow_array::Float64Array>()
837 .unwrap();
838 let rounded: arrow_array::Float64Array = arr
839 .iter()
840 .map(|v| v.map(|f| (f * 1e12).round() / 1e12))
841 .collect();
842 Arc::new(rounded) as arrow_array::ArrayRef
843 } else {
844 Arc::clone(col)
845 }
846 })
847 .collect();
848
849 RecordBatch::try_new(schema, columns).unwrap_or_else(|_| batch.clone())
850 })
851 .collect()
852}
853
854const DEDUP_ANTI_JOIN_THRESHOLD: usize = 300;
864
/// Removes from `candidates` every row already present in `existing`,
/// via a DataFusion LEFT ANTI hash join over *all* columns (nulls
/// compare equal, matching the hash-based dedup paths).
///
/// Returns `candidates` unchanged when there is nothing to dedup
/// against, and an empty result for a zero-column schema.
async fn arrow_left_anti_dedup(
    candidates: Vec<RecordBatch>,
    existing: &[RecordBatch],
    schema: &SchemaRef,
    task_ctx: &Arc<TaskContext>,
) -> DFResult<Vec<RecordBatch>> {
    if existing.is_empty() || existing.iter().all(|b| b.num_rows() == 0) {
        return Ok(candidates);
    }

    let left: Arc<dyn ExecutionPlan> = Arc::new(InMemoryExec::new(candidates, Arc::clone(schema)));
    let right: Arc<dyn ExecutionPlan> =
        Arc::new(InMemoryExec::new(existing.to_vec(), Arc::clone(schema)));

    // Join on every column: a candidate row is a duplicate only if it
    // matches an existing row in full.
    let on: Vec<(
        Arc<dyn datafusion::physical_plan::PhysicalExpr>,
        Arc<dyn datafusion::physical_plan::PhysicalExpr>,
    )> = schema
        .fields()
        .iter()
        .enumerate()
        .map(|(i, field)| {
            let l: Arc<dyn datafusion::physical_plan::PhysicalExpr> = Arc::new(
                datafusion::physical_plan::expressions::Column::new(field.name(), i),
            );
            let r: Arc<dyn datafusion::physical_plan::PhysicalExpr> = Arc::new(
                datafusion::physical_plan::expressions::Column::new(field.name(), i),
            );
            (l, r)
        })
        .collect();

    // Zero columns: no join keys possible; treat everything as duplicate.
    if on.is_empty() {
        return Ok(vec![]);
    }

    let join = HashJoinExec::try_new(
        left,
        right,
        on,
        None,
        &JoinType::LeftAnti,
        None,
        PartitionMode::CollectLeft,
        datafusion::common::NullEquality::NullEqualsNull,
    )?;

    let join_arc: Arc<dyn ExecutionPlan> = Arc::new(join);
    collect_all_partitions(&join_arc, task_ctx.clone()).await
}
919
/// Planner-produced description of one reference from a clause body to a
/// derived rule's facts (an `IsRef`), including how negation and
/// probability complements are applied during the fixpoint loop.
#[derive(Debug, Clone)]
pub struct IsRefBinding {
    /// Registry index of the referenced derived scan.
    pub derived_scan_index: usize,
    /// Name of the referenced rule.
    pub rule_name: String,
    /// NOTE(review): semantics inferred from the name (reference back to
    /// the enclosing rule); not read in this file's visible code.
    pub is_self_ref: bool,
    /// True when the reference is negated (anti-joined away).
    pub negated: bool,
    /// (body column, derived column) pairs used for the anti-join when
    /// `negated` is set.
    pub anti_join_cols: Vec<(String, String)>,
    /// Whether the referenced rule carries a probability column.
    pub target_has_prob: bool,
    /// Name of that probability column, when present.
    pub target_prob_col: Option<String>,
    /// (body column, derived column) pairs used to join delta rows back
    /// to their supporting source facts for provenance recording.
    pub provenance_join_cols: Vec<(String, String)>,
}
951
/// One clause of a fixpoint rule: its body plan plus the derived-table
/// references the fixpoint driver must resolve each iteration.
#[derive(Debug)]
pub struct FixpointClausePlan {
    /// Logical plan for the clause body, executed every iteration.
    pub body_logical: LogicalPlan,
    /// Derived-table references appearing in the body.
    pub is_ref_bindings: Vec<IsRefBinding>,
    /// Optional clause-level priority (see `PriorityExec`).
    pub priority: Option<i64>,
    /// Column names captured as ALONG values in provenance records.
    pub along_bindings: Vec<String>,
}
964
/// Complete plan for one recursive rule evaluated to fixpoint.
#[derive(Debug)]
pub struct FixpointRulePlan {
    /// Rule name; keys iteration counts and provenance records.
    pub name: String,
    /// The rule's clauses, each executed every iteration.
    pub clauses: Vec<FixpointClausePlan>,
    /// Schema of the rows the rule yields.
    pub yield_schema: SchemaRef,
    /// Indices of the grouping-key columns in `yield_schema`.
    pub key_column_indices: Vec<usize>,
    /// Optional rule-level priority.
    pub priority: Option<i64>,
    /// True when the rule has FOLD aggregates.
    pub has_fold: bool,
    /// The FOLD bindings themselves (drives `MonotonicAggState`).
    pub fold_bindings: Vec<FoldBinding>,
    /// True when the rule uses BEST BY semantics.
    pub has_best_by: bool,
    /// Sort criteria for BEST BY (used only when `has_best_by`).
    pub best_by_criteria: Vec<SortCriterion>,
    /// True when priority-based selection applies.
    pub has_priority: bool,
    /// NOTE(review): not read in this file's visible code — presumably
    /// marks rules with deterministic evaluation; confirm at call sites.
    pub deterministic: bool,
    /// Name of the probability output column, when the rule is probabilistic.
    pub prob_column_name: Option<String>,
}
995
#[expect(clippy::too_many_arguments, reason = "Fixpoint loop needs all context")]
/// Drives semi-naive fixpoint evaluation of `rules`.
///
/// Each iteration: publishes the current derived facts to the scan
/// registry, executes every clause body, applies negation/probability
/// complements, and merges the candidates into per-rule state (regular
/// delta merge or BEST BY). Stops when no rule changed and all states
/// report convergence; errors with `NonConvergence` on iteration or
/// wall-clock timeout, and with `MemoryLimitExceeded` when a rule's
/// facts exceed `max_derived_bytes`. After convergence, optionally
/// applies exact WMC re-weighting and the post-fixpoint operator chain,
/// returning the concatenated output (or one empty batch of
/// `output_schema` when nothing was derived).
async fn run_fixpoint_loop(
    rules: Vec<FixpointRulePlan>,
    max_iterations: usize,
    timeout: Duration,
    graph_ctx: Arc<GraphExecutionContext>,
    session_ctx: Arc<RwLock<datafusion::prelude::SessionContext>>,
    storage: Arc<StorageManager>,
    schema_info: Arc<UniSchema>,
    params: HashMap<String, Value>,
    registry: Arc<DerivedScanRegistry>,
    output_schema: SchemaRef,
    max_derived_bytes: usize,
    derivation_tracker: Option<Arc<ProvenanceStore>>,
    iteration_counts: Arc<StdRwLock<HashMap<String, usize>>>,
    strict_probability_domain: bool,
    probability_epsilon: f64,
    exact_probability: bool,
    max_bdd_variables: usize,
    warnings_slot: Arc<StdRwLock<Vec<RuntimeWarning>>>,
    approximate_slot: Arc<StdRwLock<HashMap<String, Vec<String>>>>,
    top_k_proofs: usize,
) -> DFResult<Vec<RecordBatch>> {
    let start = Instant::now();
    let task_ctx = session_ctx.read().task_ctx();

    // One FixpointState per rule; rules with FOLD bindings additionally
    // get a monotonic aggregate tracker.
    let mut states: Vec<FixpointState> = rules
        .iter()
        .map(|rule| {
            let monotonic_agg = if !rule.fold_bindings.is_empty() {
                let bindings: Vec<MonotonicFoldBinding> = rule
                    .fold_bindings
                    .iter()
                    .map(|fb| MonotonicFoldBinding {
                        fold_name: fb.output_name.clone(),
                        kind: fb.kind.clone(),
                        input_col_index: fb.input_col_index,
                    })
                    .collect();
                Some(MonotonicAggState::new(bindings))
            } else {
                None
            };
            FixpointState::new(
                rule.name.clone(),
                Arc::clone(&rule.yield_schema),
                rule.key_column_indices.clone(),
                max_derived_bytes,
                monotonic_agg,
                strict_probability_domain,
            )
        })
        .collect();

    let mut converged = false;
    let mut total_iters = 0usize;
    for iteration in 0..max_iterations {
        total_iters = iteration + 1;
        tracing::debug!("fixpoint iteration {}", iteration);
        let mut any_changed = false;

        for rule_idx in 0..rules.len() {
            let rule = &rules[rule_idx];

            // Make every rule's current facts visible to this rule's scans.
            update_derived_scan_handles(&registry, &states, rule_idx, &rules);

            let mut all_candidates = Vec::new();
            // Per-clause candidates are kept separately for provenance.
            let mut clause_candidates: Vec<Vec<RecordBatch>> = Vec::new();
            for clause in &rule.clauses {
                let mut batches = execute_subplan(
                    &clause.body_logical,
                    &params,
                    &HashMap::new(),
                    &graph_ctx,
                    &session_ctx,
                    &storage,
                    &schema_info,
                )
                .await?;
                // Apply negated references: either a probability
                // complement (probabilistic target + probabilistic rule)
                // or a plain anti-join.
                for binding in &clause.is_ref_bindings {
                    if binding.negated
                        && !binding.anti_join_cols.is_empty()
                        && let Some(entry) = registry.get(binding.derived_scan_index)
                    {
                        let neg_facts = entry.data.read().clone();
                        if !neg_facts.is_empty() {
                            if binding.target_has_prob && rule.prob_column_name.is_some() {
                                let complement_col =
                                    format!("__prob_complement_{}", binding.rule_name);
                                if let Some(prob_col) = &binding.target_prob_col {
                                    batches = apply_prob_complement_composite(
                                        batches,
                                        &neg_facts,
                                        &binding.anti_join_cols,
                                        prob_col,
                                        &complement_col,
                                    )?;
                                } else {
                                    batches = apply_anti_join_composite(
                                        batches,
                                        &neg_facts,
                                        &binding.anti_join_cols,
                                    )?;
                                }
                            } else {
                                batches = apply_anti_join_composite(
                                    batches,
                                    &neg_facts,
                                    &binding.anti_join_cols,
                                )?;
                            }
                        }
                    }
                }
                // Fold any complement columns produced above into the
                // rule's probability column, then drop them.
                let complement_cols: Vec<String> = if !batches.is_empty() {
                    batches[0]
                        .schema()
                        .fields()
                        .iter()
                        .filter(|f| f.name().starts_with("__prob_complement_"))
                        .map(|f| f.name().clone())
                        .collect()
                } else {
                    vec![]
                };
                if !complement_cols.is_empty() {
                    batches = multiply_prob_factors(
                        batches,
                        rule.prob_column_name.as_deref(),
                        &complement_cols,
                    )?;
                }

                clause_candidates.push(batches.clone());
                all_candidates.extend(batches);
            }

            let changed = if rule.has_best_by && !rule.best_by_criteria.is_empty() {
                states[rule_idx].merge_best_by(all_candidates, &rule.best_by_criteria)?
            } else {
                states[rule_idx]
                    .merge_delta(all_candidates, Some(Arc::clone(&task_ctx)))
                    .await?
            };
            if changed {
                any_changed = true;
                if let Some(ref tracker) = derivation_tracker {
                    record_provenance(
                        tracker,
                        rule,
                        &states[rule_idx],
                        &clause_candidates,
                        iteration,
                        &registry,
                        top_k_proofs,
                    );
                }
            }
        }

        // Convergence requires both: nothing merged this iteration AND
        // every state stable (delta empty, monotonic aggs unchanged).
        if !any_changed && states.iter().all(|s| s.is_converged()) {
            tracing::debug!("fixpoint converged after {} iterations", iteration + 1);
            converged = true;
            break;
        }

        // Timeout is checked after the iteration, so one iteration may
        // overrun the budget before the error is raised.
        if start.elapsed() > timeout {
            return Err(datafusion::error::DataFusionError::Execution(
                LocyRuntimeError::NonConvergence {
                    iterations: iteration + 1,
                }
                .to_string(),
            ));
        }
    }

    // Record how many iterations were executed (poisoned lock is ignored:
    // best-effort diagnostics only).
    if let Ok(mut counts) = iteration_counts.write() {
        for rule in &rules {
            counts.insert(rule.name.clone(), total_iters);
        }
    }

    if !converged {
        return Err(datafusion::error::DataFusionError::Execution(
            LocyRuntimeError::NonConvergence {
                iterations: max_iterations,
            }
            .to_string(),
        ));
    }

    let task_ctx = session_ctx.read().task_ctx();
    let mut all_output = Vec::new();

    // Post-fixpoint: per rule, optional exact weighted model counting
    // over shared lineage, then the post-fixpoint operator chain.
    for (rule_idx, state) in states.into_iter().enumerate() {
        let rule = &rules[rule_idx];
        let mut facts = state.into_facts();
        if facts.is_empty() {
            continue;
        }

        let shared_info = if let Some(ref tracker) = derivation_tracker {
            detect_shared_lineage(rule, &facts, tracker, &warnings_slot)
        } else {
            None
        };

        if exact_probability
            && let Some(ref info) = shared_info
            && let Some(ref tracker) = derivation_tracker
        {
            facts = apply_exact_wmc(
                facts,
                rule,
                info,
                tracker,
                max_bdd_variables,
                &warnings_slot,
                &approximate_slot,
            )?;
        }

        let processed = apply_post_fixpoint_chain(
            facts,
            rule,
            &task_ctx,
            strict_probability_domain,
            probability_epsilon,
        )
        .await?;
        all_output.extend(processed);
    }

    // Guarantee at least one (empty) batch so downstream sees the schema.
    if all_output.is_empty() {
        all_output.push(RecordBatch::new_empty(output_schema));
    }

    Ok(all_output)
}
1264
/// Records a provenance annotation for every row of the rule's current
/// delta: which clause produced it, its supporting source facts, any
/// ALONG values, and (when `top_k_proofs > 0`) a proof probability.
///
/// Row identity is the debug-formatted whole-row scalar key, matching
/// `fact_hash_key` used elsewhere for lineage lookups.
fn record_provenance(
    tracker: &Arc<ProvenanceStore>,
    rule: &FixpointRulePlan,
    state: &FixpointState,
    clause_candidates: &[Vec<RecordBatch>],
    iteration: usize,
    registry: &Arc<DerivedScanRegistry>,
    top_k_proofs: usize,
) {
    let all_indices: Vec<usize> = (0..rule.yield_schema.fields().len()).collect();

    // Base-fact probabilities are only needed for top-k proof scoring.
    let base_probs = if top_k_proofs > 0 {
        tracker.base_fact_probs()
    } else {
        HashMap::new()
    };

    for delta_batch in state.all_delta() {
        for row_idx in 0..delta_batch.num_rows() {
            let row_hash = format!(
                "{:?}",
                extract_scalar_key(delta_batch, &all_indices, row_idx)
            )
            .into_bytes();
            let fact_row = batch_row_to_value_map(delta_batch, row_idx);
            // Attribute the row to the clause whose candidates contain it.
            let clause_index =
                find_clause_for_row(delta_batch, row_idx, &all_indices, clause_candidates);

            let support = collect_is_ref_inputs(rule, clause_index, delta_batch, row_idx, registry);

            let proof_probability = if top_k_proofs > 0 {
                compute_proof_probability(&support, &base_probs)
            } else {
                None
            };

            let entry = ProvenanceAnnotation {
                rule_name: rule.name.clone(),
                clause_index,
                support,
                // Capture the clause's ALONG-bound column values present
                // in this row.
                along_values: {
                    let along_names: Vec<String> = rule
                        .clauses
                        .get(clause_index)
                        .map(|c| c.along_bindings.clone())
                        .unwrap_or_default();
                    along_names
                        .iter()
                        .filter_map(|name| fact_row.get(name).map(|v| (name.clone(), v.clone())))
                        .collect()
                },
                iteration,
                fact_row,
                proof_probability,
            };
            if top_k_proofs > 0 {
                tracker.record_top_k(row_hash, entry, top_k_proofs);
            } else {
                tracker.record(row_hash, entry);
            }
        }
    }
}
1337
1338fn collect_is_ref_inputs(
1344 rule: &FixpointRulePlan,
1345 clause_index: usize,
1346 delta_batch: &RecordBatch,
1347 row_idx: usize,
1348 registry: &Arc<DerivedScanRegistry>,
1349) -> Vec<ProofTerm> {
1350 let clause = match rule.clauses.get(clause_index) {
1351 Some(c) => c,
1352 None => return vec![],
1353 };
1354
1355 let mut inputs = Vec::new();
1356 let delta_schema = delta_batch.schema();
1357
1358 for binding in &clause.is_ref_bindings {
1359 if binding.negated {
1360 continue;
1361 }
1362 if binding.provenance_join_cols.is_empty() {
1363 continue;
1364 }
1365
1366 let body_values: Vec<(String, ScalarKey)> = binding
1368 .provenance_join_cols
1369 .iter()
1370 .filter_map(|(body_col, _derived_col)| {
1371 let col_idx = delta_schema
1372 .fields()
1373 .iter()
1374 .position(|f| f.name() == body_col)?;
1375 let key = extract_scalar_key(delta_batch, &[col_idx], row_idx);
1376 Some((body_col.clone(), key.into_iter().next()?))
1377 })
1378 .collect();
1379
1380 if body_values.len() != binding.provenance_join_cols.len() {
1381 continue;
1382 }
1383
1384 let entry = match registry.get(binding.derived_scan_index) {
1386 Some(e) => e,
1387 None => continue,
1388 };
1389 let source_batches = entry.data.read();
1390 let source_schema = &entry.schema;
1391
1392 for src_batch in source_batches.iter() {
1394 let all_src_indices: Vec<usize> = (0..src_batch.num_columns()).collect();
1395 for src_row in 0..src_batch.num_rows() {
1396 let matches = binding.provenance_join_cols.iter().enumerate().all(
1397 |(i, (_body_col, derived_col))| {
1398 let src_col_idx = source_schema
1399 .fields()
1400 .iter()
1401 .position(|f| f.name() == derived_col);
1402 match src_col_idx {
1403 Some(idx) => {
1404 let src_key = extract_scalar_key(src_batch, &[idx], src_row);
1405 src_key.first() == Some(&body_values[i].1)
1406 }
1407 None => false,
1408 }
1409 },
1410 );
1411 if matches {
1412 let fact_hash = format!(
1413 "{:?}",
1414 extract_scalar_key(src_batch, &all_src_indices, src_row)
1415 )
1416 .into_bytes();
1417 inputs.push(ProofTerm {
1418 source_rule: binding.rule_name.clone(),
1419 base_fact_id: fact_hash,
1420 });
1421 }
1422 }
1423 }
1424 }
1425
1426 inputs
1427}
1428
/// One pre-fold row of a KEY group, paired with the set of base facts its
/// proof ultimately depends on (its lineage, as computed by `compute_lineage`).
#[expect(
    dead_code,
    reason = "Fields accessed via SharedLineageInfo in detect_shared_lineage"
)]
pub(crate) struct SharedGroupRow {
    // Stable hash of the full row, as produced by `fact_hash_key`.
    pub fact_hash: Vec<u8>,
    // Base-fact hashes reachable from this row's provenance.
    pub lineage: HashSet<Vec<u8>>,
}
1455
/// Result of `detect_shared_lineage`: the KEY groups whose rows share base
/// facts, i.e. where the independence assumption of probabilistic folds
/// (MNOR/MPROD) is violated and exact WMC correction may be applied.
pub(crate) struct SharedLineageInfo {
    // KEY-column values -> rows of that group together with their lineage.
    pub shared_groups: HashMap<Vec<ScalarKey>, Vec<SharedGroupRow>>,
}
1461
1462fn fact_hash_key(batch: &RecordBatch, all_indices: &[usize], row_idx: usize) -> Vec<u8> {
1464 format!("{:?}", extract_scalar_key(batch, all_indices, row_idx)).into_bytes()
1465}
1466
/// Scans a rule's pre-fold facts for KEY groups whose rows share lineage.
///
/// Only relevant when the rule aggregates with a probabilistic fold (MNOR or
/// MPROD): those folds assume the rows of a group are independent, which
/// breaks when two rows derive from a common base fact. Returns
/// `Some(SharedLineageInfo)` describing the affected groups (and pushes a
/// `SharedProbabilisticDependency` warning, once per rule), or `None` when
/// the rule has no probabilistic fold or no sharing was found. Additionally
/// emits a `CrossGroupCorrelationNotExact` warning when a base fact supports
/// rows in more than one KEY group.
fn detect_shared_lineage(
    rule: &FixpointRulePlan,
    pre_fold_facts: &[RecordBatch],
    tracker: &Arc<ProvenanceStore>,
    warnings_slot: &Arc<StdRwLock<Vec<RuntimeWarning>>>,
) -> Option<SharedLineageInfo> {
    use crate::query::df_graph::locy_fold::FoldAggKind;
    use uni_locy::{RuntimeWarning, RuntimeWarningCode};

    // Sharing only matters for probability-combining folds.
    let has_prob_fold = rule
        .fold_bindings
        .iter()
        .any(|fb| matches!(fb.kind, FoldAggKind::Nor | FoldAggKind::Prod));
    if !has_prob_fold {
        return None;
    }

    let key_indices = &rule.key_column_indices;
    let all_indices: Vec<usize> = (0..rule.yield_schema.fields().len()).collect();

    // Bucket every pre-fold row (identified by full-row hash) under its KEY.
    let mut groups: HashMap<Vec<ScalarKey>, Vec<Vec<u8>>> = HashMap::new();
    for batch in pre_fold_facts {
        for row_idx in 0..batch.num_rows() {
            let key = extract_scalar_key(batch, key_indices, row_idx);
            let fact_hash = fact_hash_key(batch, &all_indices, row_idx);
            groups.entry(key).or_default().push(fact_hash);
        }
    }

    let mut shared_groups: HashMap<Vec<ScalarKey>, Vec<SharedGroupRow>> = HashMap::new();
    let mut any_shared = false;

    for (key, fact_hashes) in &groups {
        // A single-row group cannot share lineage with itself.
        if fact_hashes.len() < 2 {
            continue;
        }

        // Resolve each row down to its base facts; also note whether any row
        // has recorded support at all (i.e. provenance was actually tracked).
        let mut has_inputs = false;
        let mut per_row_bases: Vec<HashSet<Vec<u8>>> = Vec::new();
        for fh in fact_hashes {
            let bases = compute_lineage(fh, tracker, &mut HashSet::new());
            if let Some(entry) = tracker.lookup(fh)
                && !entry.support.is_empty()
            {
                has_inputs = true;
            }
            per_row_bases.push(bases);
        }

        let shared_found = if has_inputs {
            // Pairwise overlap test on the base-fact sets.
            let mut found = false;
            'outer: for i in 0..per_row_bases.len() {
                for j in (i + 1)..per_row_bases.len() {
                    if !per_row_bases[i].is_disjoint(&per_row_bases[j]) {
                        found = true;
                        break 'outer;
                    }
                }
            }
            found
        } else {
            // No lineage recorded: fall back to a structural check — treat the
            // group as suspect if any row came from a clause with a positive
            // IS-ref binding (it *could* share inputs that were not tracked).
            fact_hashes.iter().any(|fh| {
                tracker.lookup(fh).is_some_and(|entry| {
                    rule.clauses
                        .get(entry.clause_index)
                        .is_some_and(|clause| clause.is_ref_bindings.iter().any(|b| !b.negated))
                })
            })
        };

        if shared_found {
            any_shared = true;
            let rows: Vec<SharedGroupRow> = fact_hashes
                .iter()
                .zip(per_row_bases.into_iter())
                .map(|(fh, bases)| SharedGroupRow {
                    fact_hash: fh.clone(),
                    lineage: bases,
                })
                .collect();
            shared_groups.insert(key.clone(), rows);
        }
    }

    // Independently of per-group sharing, warn (once per rule) when a base
    // fact supports rows in more than one KEY group: per-group BDD correction
    // cannot model correlations that cross group boundaries.
    {
        let mut input_to_groups: HashMap<Vec<u8>, HashSet<Vec<ScalarKey>>> = HashMap::new();
        for (key, fact_hashes) in &groups {
            for fh in fact_hashes {
                if let Some(entry) = tracker.lookup(fh) {
                    for input in &entry.support {
                        input_to_groups
                            .entry(input.base_fact_id.clone())
                            .or_default()
                            .insert(key.clone());
                    }
                }
            }
        }
        let has_cross_group = input_to_groups.values().any(|g| g.len() > 1);
        if has_cross_group && let Ok(mut warnings) = warnings_slot.write() {
            let already_warned = warnings.iter().any(|w| {
                w.code == RuntimeWarningCode::CrossGroupCorrelationNotExact
                    && w.rule_name == rule.name
            });
            if !already_warned {
                warnings.push(RuntimeWarning {
                    code: RuntimeWarningCode::CrossGroupCorrelationNotExact,
                    message: format!(
                        "Rule '{}': IS-ref base facts are shared across different KEY \
                         groups. BDD corrects per-group probabilities but cannot account \
                         for cross-group correlations.",
                        rule.name
                    ),
                    rule_name: rule.name.clone(),
                    variable_count: None,
                    key_group: None,
                });
            }
        }
    }

    if any_shared {
        // Warn once per rule that the independence assumption is violated.
        if let Ok(mut warnings) = warnings_slot.write() {
            let already_warned = warnings.iter().any(|w| {
                w.code == RuntimeWarningCode::SharedProbabilisticDependency
                    && w.rule_name == rule.name
            });
            if !already_warned {
                warnings.push(RuntimeWarning {
                    code: RuntimeWarningCode::SharedProbabilisticDependency,
                    message: format!(
                        "Rule '{}' aggregates with MNOR/MPROD but some proof paths \
                         share intermediate facts, violating the independence assumption. \
                         Results may overestimate probability.",
                        rule.name
                    ),
                    rule_name: rule.name.clone(),
                    variable_count: None,
                    key_group: None,
                });
            }
        }
        Some(SharedLineageInfo { shared_groups })
    } else {
        None
    }
}
1629
/// Records provenance for every fact of a non-recursive rule (iteration 0)
/// and then runs shared-lineage detection over the combined facts.
///
/// `tagged_clause_facts` pairs each producing clause index with its output
/// batches. When `top_k_proofs > 0`, proof probabilities are computed from
/// the tracker's base-fact probabilities and only the top-k proofs per fact
/// are retained; otherwise every proof is recorded.
pub(crate) fn record_and_detect_lineage_nonrecursive(
    rule: &FixpointRulePlan,
    tagged_clause_facts: &[(usize, Vec<RecordBatch>)],
    tracker: &Arc<ProvenanceStore>,
    warnings_slot: &Arc<StdRwLock<Vec<RuntimeWarning>>>,
    registry: &Arc<DerivedScanRegistry>,
    top_k_proofs: usize,
) -> Option<SharedLineageInfo> {
    let all_indices: Vec<usize> = (0..rule.yield_schema.fields().len()).collect();

    // Base probabilities are only needed for top-k proof scoring.
    let base_probs = if top_k_proofs > 0 {
        tracker.base_fact_probs()
    } else {
        HashMap::new()
    };

    for (clause_index, batches) in tagged_clause_facts {
        for batch in batches {
            for row_idx in 0..batch.num_rows() {
                let row_hash = fact_hash_key(batch, &all_indices, row_idx);
                let fact_row = batch_row_to_value_map(batch, row_idx);

                // Base facts from IS-ref bindings that support this row.
                let support = collect_is_ref_inputs(rule, *clause_index, batch, row_idx, registry);

                let proof_probability = if top_k_proofs > 0 {
                    compute_proof_probability(&support, &base_probs)
                } else {
                    None
                };

                let entry = ProvenanceAnnotation {
                    rule_name: rule.name.clone(),
                    clause_index: *clause_index,
                    support,
                    // Capture the clause's ALONG-bound column values present
                    // in this row, keyed by binding name.
                    along_values: {
                        let along_names: Vec<String> = rule
                            .clauses
                            .get(*clause_index)
                            .map(|c| c.along_bindings.clone())
                            .unwrap_or_default();
                        along_names
                            .iter()
                            .filter_map(|name| {
                                fact_row.get(name).map(|v| (name.clone(), v.clone()))
                            })
                            .collect()
                    },
                    // Non-recursive rules always record at iteration 0.
                    iteration: 0,
                    fact_row,
                    proof_probability,
                };
                if top_k_proofs > 0 {
                    tracker.record_top_k(row_hash, entry, top_k_proofs);
                } else {
                    tracker.record(row_hash, entry);
                }
            }
        }
    }

    // Detection runs over all clauses' facts combined.
    let all_facts: Vec<RecordBatch> = tagged_clause_facts
        .iter()
        .flat_map(|(_, batches)| batches.iter().cloned())
        .collect();
    detect_shared_lineage(rule, &all_facts, tracker, warnings_slot)
}
1706
/// Replaces each shared KEY group's rows with a single representative whose
/// probability is the exact weighted model count (WMC) over the group's base
/// facts, computed via BDD.
///
/// Rows whose KEY is not in `shared_info` pass through unchanged. For each
/// shared group, the lineage of every row is collected along with each base
/// fact's probability (read from the tracker's stored fact rows); the BDD
/// then combines them under the fold's semiring (disjunction for MNOR,
/// conjunction for MPROD). If the BDD exceeds `max_bdd_variables`, the group
/// falls back to its original rows (independence-mode), a `BddLimitExceeded`
/// warning is pushed once per (rule, group), and the group is recorded in
/// `approximate_slot`. Returns the filtered batches with overridden
/// probability values; no-op when the rule has no probabilistic fold.
pub(crate) fn apply_exact_wmc(
    pre_fold_facts: Vec<RecordBatch>,
    rule: &FixpointRulePlan,
    shared_info: &SharedLineageInfo,
    tracker: &Arc<ProvenanceStore>,
    max_bdd_variables: usize,
    warnings_slot: &Arc<StdRwLock<Vec<RuntimeWarning>>>,
    approximate_slot: &Arc<StdRwLock<HashMap<String, Vec<String>>>>,
) -> DFResult<Vec<RecordBatch>> {
    use crate::query::df_graph::locy_bdd::{SemiringOp, weighted_model_count};
    use crate::query::df_graph::locy_fold::FoldAggKind;
    use uni_locy::{RuntimeWarning, RuntimeWarningCode};

    // Nothing to correct without a probabilistic fold.
    let prob_fold = rule
        .fold_bindings
        .iter()
        .find(|fb| matches!(fb.kind, FoldAggKind::Nor | FoldAggKind::Prod));
    let prob_fold = match prob_fold {
        Some(f) => f,
        None => return Ok(pre_fold_facts),
    };
    // MNOR combines proofs disjunctively, MPROD conjunctively.
    let semiring_op = if matches!(prob_fold.kind, FoldAggKind::Nor) {
        SemiringOp::Disjunction
    } else {
        SemiringOp::Conjunction
    };
    let prob_col_idx = prob_fold.input_col_index;
    let prob_col_name = rule.yield_schema.field(prob_col_idx).name().clone();

    let key_indices = &rule.key_column_indices;
    let all_indices: Vec<usize> = (0..rule.yield_schema.fields().len()).collect();

    let shared_keys: HashSet<Vec<ScalarKey>> = shared_info.shared_groups.keys().cloned().collect();

    // Per-group accumulator: one lineage set per row, a base-fact probability
    // table, the first row seen (representative), and all row locations for
    // the fallback path.
    struct GroupAccum {
        base_facts: Vec<HashSet<Vec<u8>>>,
        base_probs: HashMap<Vec<u8>, f64>,
        representative: (usize, usize),
        row_locations: Vec<(usize, usize)>,
    }

    let mut group_accums: HashMap<Vec<ScalarKey>, GroupAccum> = HashMap::new();
    let mut non_shared_rows: Vec<(usize, usize)> = Vec::new();
    for (batch_idx, batch) in pre_fold_facts.iter().enumerate() {
        for row_idx in 0..batch.num_rows() {
            let key = extract_scalar_key(batch, key_indices, row_idx);
            if shared_keys.contains(&key) {
                let fact_hash = fact_hash_key(batch, &all_indices, row_idx);
                let bases = compute_lineage(&fact_hash, tracker, &mut HashSet::new());

                let accum = group_accums.entry(key).or_insert_with(|| GroupAccum {
                    base_facts: Vec::new(),
                    base_probs: HashMap::new(),
                    representative: (batch_idx, row_idx),
                    row_locations: Vec::new(),
                });

                // Record each base fact's probability once, taken from the
                // tracked fact row's probability column.
                for bf in &bases {
                    if !accum.base_probs.contains_key(bf)
                        && let Some(entry) = tracker.lookup(bf)
                        && let Some(val) = entry.fact_row.get(&prob_col_name)
                        && let Some(p) = value_to_f64(val)
                    {
                        accum.base_probs.insert(bf.clone(), p);
                    }
                }

                accum.base_facts.push(bases);
                accum.row_locations.push((batch_idx, row_idx));
            } else {
                non_shared_rows.push((batch_idx, row_idx));
            }
        }
    }

    // (batch, row) locations to keep in the output, and per-location
    // probability overrides produced by the BDD.
    let mut keep_rows: HashSet<(usize, usize)> = HashSet::new();
    let mut overrides: HashMap<(usize, usize), f64> = HashMap::new();

    for &loc in &non_shared_rows {
        keep_rows.insert(loc);
    }

    for (key, accum) in &group_accums {
        let bdd_result = weighted_model_count(
            &accum.base_facts,
            &accum.base_probs,
            semiring_op,
            max_bdd_variables,
        );

        if bdd_result.approximated {
            // BDD blew past the variable limit: keep all original rows
            // (independence-mode result) and surface the approximation.
            if let Ok(mut warnings) = warnings_slot.write() {
                let key_desc = format!("{key:?}");
                let already_warned = warnings.iter().any(|w| {
                    w.code == RuntimeWarningCode::BddLimitExceeded
                        && w.rule_name == rule.name
                        && w.key_group.as_deref() == Some(&key_desc)
                });
                if !already_warned {
                    warnings.push(RuntimeWarning {
                        code: RuntimeWarningCode::BddLimitExceeded,
                        message: format!(
                            "Rule '{}': BDD variable limit exceeded ({} > {}). \
                             Falling back to independence-mode result.",
                            rule.name, bdd_result.variable_count, max_bdd_variables
                        ),
                        rule_name: rule.name.clone(),
                        variable_count: Some(bdd_result.variable_count),
                        key_group: Some(key_desc),
                    });
                }
            }
            if let Ok(mut approx) = approximate_slot.write() {
                let key_desc = format!("{key:?}");
                approx.entry(rule.name.clone()).or_default().push(key_desc);
            }
            for &loc in &accum.row_locations {
                keep_rows.insert(loc);
            }
        } else {
            // Exact count: the representative row carries the corrected
            // probability; the group's other rows are dropped.
            keep_rows.insert(accum.representative);
            overrides.insert(accum.representative, bdd_result.probability);
        }
    }

    // Rebuild each batch: take only the kept rows, then patch the
    // probability column where an override exists.
    let mut result_batches = Vec::new();
    for (batch_idx, batch) in pre_fold_facts.iter().enumerate() {
        let kept_indices: Vec<usize> = (0..batch.num_rows())
            .filter(|&row_idx| keep_rows.contains(&(batch_idx, row_idx)))
            .collect();

        if kept_indices.is_empty() {
            continue;
        }

        let indices = arrow::array::UInt32Array::from(
            kept_indices.iter().map(|&i| i as u32).collect::<Vec<_>>(),
        );
        let mut columns: Vec<arrow::array::ArrayRef> = batch
            .columns()
            .iter()
            .map(|col| arrow::compute::take(col, &indices, None))
            .collect::<Result<Vec<_>, _>>()
            .map_err(arrow_err)?;

        // Overrides aligned with the kept (post-take) row order.
        let override_map: Vec<Option<f64>> = kept_indices
            .iter()
            .map(|&row_idx| overrides.get(&(batch_idx, row_idx)).copied())
            .collect();

        if override_map.iter().any(|o| o.is_some()) {
            let existing_prob = columns[prob_col_idx]
                .as_any()
                .downcast_ref::<arrow::array::Float64Array>();
            let new_values: Vec<f64> = override_map
                .iter()
                .enumerate()
                .map(|(i, ov)| match ov {
                    Some(p) => *p,
                    None => existing_prob.map(|arr| arr.value(i)).unwrap_or(0.0),
                })
                .collect();
            columns[prob_col_idx] = Arc::new(arrow::array::Float64Array::from(new_values));
        }

        let result_batch = RecordBatch::try_new(batch.schema(), columns).map_err(arrow_err)?;
        result_batches.push(result_batch);
    }

    Ok(result_batches)
}
1902
1903fn value_to_f64(val: &uni_common::Value) -> Option<f64> {
1905 match val {
1906 uni_common::Value::Float(f) => Some(*f),
1907 uni_common::Value::Int(i) => Some(*i as f64),
1908 _ => None,
1909 }
1910}
1911
1912fn compute_lineage(
1919 fact_hash: &[u8],
1920 tracker: &Arc<ProvenanceStore>,
1921 visited: &mut HashSet<Vec<u8>>,
1922) -> HashSet<Vec<u8>> {
1923 if !visited.insert(fact_hash.to_vec()) {
1924 return HashSet::new(); }
1926
1927 match tracker.lookup(fact_hash) {
1928 Some(entry) if !entry.support.is_empty() => {
1929 let mut bases = HashSet::new();
1930 for input in &entry.support {
1931 let child_bases = compute_lineage(&input.base_fact_id, tracker, visited);
1932 bases.extend(child_bases);
1933 }
1934 bases
1935 }
1936 _ => {
1937 let mut set = HashSet::new();
1939 set.insert(fact_hash.to_vec());
1940 set
1941 }
1942 }
1943}
1944
1945fn find_clause_for_row(
1950 delta_batch: &RecordBatch,
1951 row_idx: usize,
1952 all_indices: &[usize],
1953 clause_candidates: &[Vec<RecordBatch>],
1954) -> usize {
1955 let target_key = extract_scalar_key(delta_batch, all_indices, row_idx);
1956 for (clause_idx, batches) in clause_candidates.iter().enumerate() {
1957 for batch in batches {
1958 if batch.num_columns() != all_indices.len() {
1959 continue;
1960 }
1961 for r in 0..batch.num_rows() {
1962 if extract_scalar_key(batch, all_indices, r) == target_key {
1963 return clause_idx;
1964 }
1965 }
1966 }
1967 }
1968 0
1969}
1970
1971fn batch_row_to_value_map(
1973 batch: &RecordBatch,
1974 row_idx: usize,
1975) -> std::collections::HashMap<String, Value> {
1976 use uni_store::storage::arrow_convert::arrow_to_value;
1977
1978 let schema = batch.schema();
1979 schema
1980 .fields()
1981 .iter()
1982 .enumerate()
1983 .map(|(col_idx, field)| {
1984 let col = batch.column(col_idx);
1985 let val = arrow_to_value(col.as_ref(), row_idx, None);
1986 (field.name().clone(), val)
1987 })
1988 .collect()
1989}
1990
1991pub fn apply_anti_join(
1996 batches: Vec<RecordBatch>,
1997 neg_facts: &[RecordBatch],
1998 left_col: &str,
1999 right_col: &str,
2000) -> datafusion::error::Result<Vec<RecordBatch>> {
2001 use arrow::compute::filter_record_batch;
2002 use arrow_array::{Array as _, BooleanArray, UInt64Array};
2003
2004 let mut banned: std::collections::HashSet<u64> = std::collections::HashSet::new();
2006 for batch in neg_facts {
2007 let Ok(idx) = batch.schema().index_of(right_col) else {
2008 continue;
2009 };
2010 let arr = batch.column(idx);
2011 let Some(vids) = arr.as_any().downcast_ref::<UInt64Array>() else {
2012 continue;
2013 };
2014 for i in 0..vids.len() {
2015 if !vids.is_null(i) {
2016 banned.insert(vids.value(i));
2017 }
2018 }
2019 }
2020
2021 if banned.is_empty() {
2022 return Ok(batches);
2023 }
2024
2025 let mut result = Vec::new();
2027 for batch in batches {
2028 let Ok(idx) = batch.schema().index_of(left_col) else {
2029 result.push(batch);
2030 continue;
2031 };
2032 let arr = batch.column(idx);
2033 let Some(vids) = arr.as_any().downcast_ref::<UInt64Array>() else {
2034 result.push(batch);
2035 continue;
2036 };
2037 let keep: Vec<bool> = (0..vids.len())
2038 .map(|i| vids.is_null(i) || !banned.contains(&vids.value(i)))
2039 .collect();
2040 let keep_arr = BooleanArray::from(keep);
2041 let filtered = filter_record_batch(&batch, &keep_arr).map_err(arrow_err)?;
2042 if filtered.num_rows() > 0 {
2043 result.push(filtered);
2044 }
2045 }
2046 Ok(result)
2047}
2048
/// Appends a complement-probability column for a single-column negation join.
///
/// For each negated fact id, the probabilities of all matching negative
/// facts are combined with noisy-or; each input row then gains a
/// `complement_col_name` Float64 column holding `1 - P(matched)`. Rows with
/// null ids (or batches missing/mistyped on `left_col`) get no complement
/// column and pass through unchanged.
pub fn apply_prob_complement(
    batches: Vec<RecordBatch>,
    neg_facts: &[RecordBatch],
    left_col: &str,
    right_col: &str,
    prob_col: &str,
    complement_col_name: &str,
) -> datafusion::error::Result<Vec<RecordBatch>> {
    use arrow_array::{Array as _, Float64Array, UInt64Array};

    // id -> combined probability of the negative facts carrying that id.
    let mut prob_map: std::collections::HashMap<u64, f64> = std::collections::HashMap::new();
    for batch in neg_facts {
        let Ok(vid_idx) = batch.schema().index_of(right_col) else {
            continue;
        };
        let Ok(prob_idx) = batch.schema().index_of(prob_col) else {
            continue;
        };
        let Some(vids) = batch.column(vid_idx).as_any().downcast_ref::<UInt64Array>() else {
            continue;
        };
        let prob_arr = batch.column(prob_idx);
        let probs = prob_arr.as_any().downcast_ref::<Float64Array>();
        for i in 0..vids.len() {
            if !vids.is_null(i) {
                // Null / missing / mistyped probability defaults to 0.0.
                let p = probs
                    .and_then(|arr| {
                        if arr.is_null(i) {
                            None
                        } else {
                            Some(arr.value(i))
                        }
                    })
                    .unwrap_or(0.0);
                // Combine duplicate ids with noisy-or:
                // p_new = 1 - (1 - p_old)(1 - p).
                prob_map
                    .entry(vids.value(i))
                    .and_modify(|existing| {
                        *existing = 1.0 - (1.0 - *existing) * (1.0 - p);
                    })
                    .or_insert(p);
            }
        }
    }

    let mut result = Vec::new();
    for batch in batches {
        let Ok(idx) = batch.schema().index_of(left_col) else {
            result.push(batch);
            continue;
        };
        let Some(vids) = batch.column(idx).as_any().downcast_ref::<UInt64Array>() else {
            result.push(batch);
            continue;
        };

        // Complement per row; null ids and unmatched ids complement to 1.0.
        let complements: Vec<f64> = (0..vids.len())
            .map(|i| {
                if vids.is_null(i) {
                    1.0
                } else {
                    let p = prob_map.get(&vids.value(i)).copied().unwrap_or(0.0);
                    1.0 - p
                }
            })
            .collect();

        let complement_arr = Float64Array::from(complements);

        // Rebuild the batch with the extra nullable Float64 column appended.
        let mut columns: Vec<arrow_array::ArrayRef> = batch.columns().to_vec();
        columns.push(std::sync::Arc::new(complement_arr));

        let mut fields: Vec<std::sync::Arc<arrow_schema::Field>> =
            batch.schema().fields().iter().cloned().collect();
        fields.push(std::sync::Arc::new(arrow_schema::Field::new(
            complement_col_name,
            arrow_schema::DataType::Float64,
            true,
        )));

        let new_schema = std::sync::Arc::new(arrow_schema::Schema::new(fields));
        let new_batch = RecordBatch::try_new(new_schema, columns).map_err(arrow_err)?;
        result.push(new_batch);
    }
    Ok(result)
}
2148
/// Composite-key variant of `apply_prob_complement`: the negation join is on
/// multiple `(left, right)` `UInt64` column pairs.
///
/// Negative-fact probabilities for equal composite keys combine via
/// noisy-or; each input row gains a `complement_col_name` Float64 column
/// holding `1 - P(matched)`. Rows with any null key component complement to
/// 1.0; batches missing a join column, or with a non-u64 join column, pass
/// through unchanged.
pub fn apply_prob_complement_composite(
    batches: Vec<RecordBatch>,
    neg_facts: &[RecordBatch],
    join_cols: &[(String, String)],
    prob_col: &str,
    complement_col_name: &str,
) -> datafusion::error::Result<Vec<RecordBatch>> {
    use arrow_array::{Array as _, Float64Array, UInt64Array};

    // Composite key -> combined probability of matching negative facts.
    let mut prob_map: HashMap<Vec<u64>, f64> = HashMap::new();
    for batch in neg_facts {
        let right_indices: Vec<usize> = join_cols
            .iter()
            .filter_map(|(_, rc)| batch.schema().index_of(rc).ok())
            .collect();
        // All right-side join columns must be present.
        if right_indices.len() != join_cols.len() {
            continue;
        }
        let Ok(prob_idx) = batch.schema().index_of(prob_col) else {
            continue;
        };
        let prob_arr = batch.column(prob_idx);
        let probs = prob_arr.as_any().downcast_ref::<Float64Array>();
        for row in 0..batch.num_rows() {
            // Build the composite key; any null or non-u64 component
            // invalidates the row.
            let mut key = Vec::with_capacity(right_indices.len());
            let mut valid = true;
            for &ci in &right_indices {
                let col = batch.column(ci);
                if let Some(vids) = col.as_any().downcast_ref::<UInt64Array>() {
                    if vids.is_null(row) {
                        valid = false;
                        break;
                    }
                    key.push(vids.value(row));
                } else {
                    valid = false;
                    break;
                }
            }
            if !valid {
                continue;
            }
            // Null / missing / mistyped probability defaults to 0.0.
            let p = probs
                .and_then(|arr| {
                    if arr.is_null(row) {
                        None
                    } else {
                        Some(arr.value(row))
                    }
                })
                .unwrap_or(0.0);
            // Combine duplicate keys with noisy-or.
            prob_map
                .entry(key)
                .and_modify(|existing| {
                    *existing = 1.0 - (1.0 - *existing) * (1.0 - p);
                })
                .or_insert(p);
        }
    }

    let mut result = Vec::new();
    for batch in batches {
        let left_indices: Vec<usize> = join_cols
            .iter()
            .filter_map(|(lc, _)| batch.schema().index_of(lc).ok())
            .collect();
        if left_indices.len() != join_cols.len() {
            result.push(batch);
            continue;
        }
        let all_u64 = left_indices.iter().all(|&ci| {
            batch
                .column(ci)
                .as_any()
                .downcast_ref::<UInt64Array>()
                .is_some()
        });
        if !all_u64 {
            result.push(batch);
            continue;
        }

        // Complement per row; a null key component complements to 1.0.
        let complements: Vec<f64> = (0..batch.num_rows())
            .map(|row| {
                let mut key = Vec::with_capacity(left_indices.len());
                for &ci in &left_indices {
                    let vids = batch
                        .column(ci)
                        .as_any()
                        .downcast_ref::<UInt64Array>()
                        .unwrap();
                    if vids.is_null(row) {
                        return 1.0;
                    }
                    key.push(vids.value(row));
                }
                let p = prob_map.get(&key).copied().unwrap_or(0.0);
                1.0 - p
            })
            .collect();

        // Rebuild the batch with the extra nullable Float64 column appended.
        let complement_arr = Float64Array::from(complements);
        let mut columns: Vec<arrow_array::ArrayRef> = batch.columns().to_vec();
        columns.push(Arc::new(complement_arr));

        let mut fields: Vec<Arc<arrow_schema::Field>> =
            batch.schema().fields().iter().cloned().collect();
        fields.push(Arc::new(arrow_schema::Field::new(
            complement_col_name,
            arrow_schema::DataType::Float64,
            true,
        )));

        let new_schema = Arc::new(arrow_schema::Schema::new(fields));
        let new_batch = RecordBatch::try_new(new_schema, columns).map_err(arrow_err)?;
        result.push(new_batch);
    }
    Ok(result)
}
2277
/// Composite-key variant of `apply_anti_join`: removes rows whose composite
/// `UInt64` key (over the left columns of `join_cols`) appears among the
/// negative facts' right-column keys.
///
/// Rows with any null key component are always kept. Batches missing a join
/// column, or with a non-u64 join column, pass through unchanged;
/// fully-filtered batches are dropped from the output.
pub fn apply_anti_join_composite(
    batches: Vec<RecordBatch>,
    neg_facts: &[RecordBatch],
    join_cols: &[(String, String)],
) -> datafusion::error::Result<Vec<RecordBatch>> {
    use arrow::compute::filter_record_batch;
    use arrow_array::{Array as _, BooleanArray, UInt64Array};

    // Collect the set of composite keys to exclude.
    let mut banned: HashSet<Vec<u64>> = HashSet::new();
    for batch in neg_facts {
        let right_indices: Vec<usize> = join_cols
            .iter()
            .filter_map(|(_, rc)| batch.schema().index_of(rc).ok())
            .collect();
        // All right-side join columns must be present.
        if right_indices.len() != join_cols.len() {
            continue;
        }
        for row in 0..batch.num_rows() {
            // Build the composite key; any null or non-u64 component
            // invalidates the row.
            let mut key = Vec::with_capacity(right_indices.len());
            let mut valid = true;
            for &ci in &right_indices {
                let col = batch.column(ci);
                if let Some(vids) = col.as_any().downcast_ref::<UInt64Array>() {
                    if vids.is_null(row) {
                        valid = false;
                        break;
                    }
                    key.push(vids.value(row));
                } else {
                    valid = false;
                    break;
                }
            }
            if valid {
                banned.insert(key);
            }
        }
    }

    // Nothing banned: hand the input back unchanged.
    if banned.is_empty() {
        return Ok(batches);
    }

    let mut result = Vec::new();
    for batch in batches {
        let left_indices: Vec<usize> = join_cols
            .iter()
            .filter_map(|(lc, _)| batch.schema().index_of(lc).ok())
            .collect();
        if left_indices.len() != join_cols.len() {
            result.push(batch);
            continue;
        }
        let all_u64 = left_indices.iter().all(|&ci| {
            batch
                .column(ci)
                .as_any()
                .downcast_ref::<UInt64Array>()
                .is_some()
        });
        if !all_u64 {
            result.push(batch);
            continue;
        }

        // Keep a row when any key component is null or the key is not banned.
        let keep: Vec<bool> = (0..batch.num_rows())
            .map(|row| {
                let mut key = Vec::with_capacity(left_indices.len());
                for &ci in &left_indices {
                    let vids = batch
                        .column(ci)
                        .as_any()
                        .downcast_ref::<UInt64Array>()
                        .unwrap();
                    if vids.is_null(row) {
                        return true;
                    }
                    key.push(vids.value(row));
                }
                !banned.contains(&key)
            })
            .collect();
        let keep_arr = BooleanArray::from(keep);
        let filtered = filter_record_batch(&batch, &keep_arr).map_err(arrow_err)?;
        if filtered.num_rows() > 0 {
            result.push(filtered);
        }
    }
    Ok(result)
}
2375
/// Multiplies each row's probability by its complement-factor columns, then
/// drops those complement columns from the output schema.
///
/// The factor product starts at 1.0 and skips null factor cells. When
/// `prob_col` is given (and present with Float64 type), its value is
/// multiplied into the product and replaced in place; null probabilities
/// fall back to the bare factor product. Empty batches only have the
/// complement columns removed. Errors if a named column exists but is not
/// Float64.
pub fn multiply_prob_factors(
    batches: Vec<RecordBatch>,
    prob_col: Option<&str>,
    complement_cols: &[String],
) -> datafusion::error::Result<Vec<RecordBatch>> {
    use arrow_array::{Array as _, Float64Array};

    let mut result = Vec::with_capacity(batches.len());

    for batch in batches {
        if batch.num_rows() == 0 {
            // Empty batch: just project away the complement columns so the
            // output schema matches the non-empty case.
            let keep: Vec<usize> = batch
                .schema()
                .fields()
                .iter()
                .enumerate()
                .filter(|(_, f)| !complement_cols.contains(f.name()))
                .map(|(i, _)| i)
                .collect();
            let fields: Vec<_> = keep
                .iter()
                .map(|&i| batch.schema().field(i).clone())
                .collect();
            let cols: Vec<_> = keep.iter().map(|&i| batch.column(i).clone()).collect();
            let schema = std::sync::Arc::new(arrow_schema::Schema::new(fields));
            result.push(
                RecordBatch::try_new(schema, cols).map_err(|e| {
                    datafusion::error::DataFusionError::ArrowError(Box::new(e), None)
                })?,
            );
            continue;
        }

        let num_rows = batch.num_rows();

        // Running product of all complement factors, per row.
        let mut combined = vec![1.0f64; num_rows];
        for col_name in complement_cols {
            if let Ok(idx) = batch.schema().index_of(col_name) {
                let arr = batch
                    .column(idx)
                    .as_any()
                    .downcast_ref::<Float64Array>()
                    .ok_or_else(|| {
                        datafusion::error::DataFusionError::Internal(format!(
                            "Expected Float64 for complement column {col_name}"
                        ))
                    })?;
                // Null factor cells are treated as a neutral 1.0 (skipped).
                for (i, val) in combined.iter_mut().enumerate().take(num_rows) {
                    if !arr.is_null(i) {
                        *val *= arr.value(i);
                    }
                }
            }
        }

        // Fold the existing probability column into the product, when present.
        let final_prob: Vec<f64> = if let Some(prob_name) = prob_col {
            if let Ok(idx) = batch.schema().index_of(prob_name) {
                let arr = batch
                    .column(idx)
                    .as_any()
                    .downcast_ref::<Float64Array>()
                    .ok_or_else(|| {
                        datafusion::error::DataFusionError::Internal(format!(
                            "Expected Float64 for PROB column {prob_name}"
                        ))
                    })?;
                (0..num_rows)
                    .map(|i| {
                        if arr.is_null(i) {
                            combined[i]
                        } else {
                            arr.value(i) * combined[i]
                        }
                    })
                    .collect()
            } else {
                combined
            }
        } else {
            combined
        };

        let new_prob_array: arrow_array::ArrayRef =
            std::sync::Arc::new(Float64Array::from(final_prob));

        // Rebuild the batch: drop complement columns, swap the probability
        // column for the recomputed values, keep everything else.
        let mut fields = Vec::new();
        let mut columns = Vec::new();

        for (idx, field) in batch.schema().fields().iter().enumerate() {
            if complement_cols.contains(field.name()) {
                continue;
            }
            if prob_col.is_some_and(|p| field.name() == p) {
                fields.push(field.clone());
                columns.push(new_prob_array.clone());
            } else {
                fields.push(field.clone());
                columns.push(batch.column(idx).clone());
            }
        }

        let schema = std::sync::Arc::new(arrow_schema::Schema::new(fields));
        result.push(RecordBatch::try_new(schema, columns).map_err(arrow_err)?);
    }

    Ok(result)
}
2497
2498fn update_derived_scan_handles(
2503 registry: &DerivedScanRegistry,
2504 states: &[FixpointState],
2505 current_rule_idx: usize,
2506 rules: &[FixpointRulePlan],
2507) {
2508 let current_rule_name = &rules[current_rule_idx].name;
2509
2510 for entry in ®istry.entries {
2511 let source_state_idx = rules.iter().position(|r| r.name == entry.rule_name);
2513 let Some(source_idx) = source_state_idx else {
2514 continue;
2515 };
2516
2517 let is_self = entry.rule_name == *current_rule_name;
2518 let data = if is_self {
2519 states[source_idx].all_delta().to_vec()
2521 } else {
2522 states[source_idx].all_facts().to_vec()
2524 };
2525
2526 let data = if data.is_empty() || data.iter().all(|b| b.num_rows() == 0) {
2528 vec![RecordBatch::new_empty(Arc::clone(&entry.schema))]
2529 } else {
2530 data
2531 };
2532
2533 let mut guard = entry.data.write();
2534 *guard = data;
2535 }
2536}
2537
/// Leaf execution node that serves record batches from a shared, mutable
/// slot. The fixpoint driver swaps the slot's contents between iterations
/// (see `update_derived_scan_handles`), so re-executing the same physical
/// plan observes fresh data.
pub struct DerivedScanExec {
    // Batches to serve; replaced wholesale between fixpoint iterations.
    data: Arc<RwLock<Vec<RecordBatch>>>,
    // Output schema; also used for the empty-batch fallback in `execute`.
    schema: SchemaRef,
    // Cached DataFusion plan properties derived from `schema`.
    properties: PlanProperties,
}

impl DerivedScanExec {
    /// Creates a scan over the shared `data` slot with plan properties
    /// derived from `schema`.
    pub fn new(data: Arc<RwLock<Vec<RecordBatch>>>, schema: SchemaRef) -> Self {
        let properties = compute_plan_properties(Arc::clone(&schema));
        Self {
            data,
            schema,
            properties,
        }
    }
}
2563
// Manual Debug: the batch data sits behind a lock and is deliberately
// omitted; only the schema is shown.
impl fmt::Debug for DerivedScanExec {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("DerivedScanExec")
            .field("schema", &self.schema)
            .finish()
    }
}
2571
// EXPLAIN rendering: just the node name, identical for every format type.
impl DisplayAs for DerivedScanExec {
    fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "DerivedScanExec")
    }
}
2577
impl ExecutionPlan for DerivedScanExec {
    fn name(&self) -> &str {
        "DerivedScanExec"
    }
    fn as_any(&self) -> &dyn Any {
        self
    }
    fn schema(&self) -> SchemaRef {
        Arc::clone(&self.schema)
    }
    fn properties(&self) -> &PlanProperties {
        &self.properties
    }
    // Leaf node: no children.
    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
        vec![]
    }
    fn with_new_children(
        self: Arc<Self>,
        _children: Vec<Arc<dyn ExecutionPlan>>,
    ) -> DFResult<Arc<dyn ExecutionPlan>> {
        Ok(self)
    }
    fn execute(
        &self,
        _partition: usize,
        _context: Arc<TaskContext>,
    ) -> DFResult<SendableRecordBatchStream> {
        // Snapshot the shared slot under the read lock so the stream is
        // unaffected by later swaps; serve one empty batch when there is no
        // data so consumers still see the schema.
        let batches = {
            let guard = self.data.read();
            if guard.is_empty() {
                vec![RecordBatch::new_empty(Arc::clone(&self.schema))]
            } else {
                guard.clone()
            }
        };
        Ok(Box::pin(MemoryStream::try_new(
            batches,
            Arc::clone(&self.schema),
            None,
        )?))
    }
}
2620
/// Simple leaf execution node over a fixed, owned set of record batches.
/// Used to feed already-materialized facts into post-fixpoint operators
/// (see `apply_post_fixpoint_chain`).
struct InMemoryExec {
    // The batches this node streams, fixed at construction.
    batches: Vec<RecordBatch>,
    // Output schema of the batches.
    schema: SchemaRef,
    // Cached DataFusion plan properties derived from `schema`.
    properties: PlanProperties,
}

impl InMemoryExec {
    /// Wraps `batches` with plan properties derived from `schema`.
    fn new(batches: Vec<RecordBatch>, schema: SchemaRef) -> Self {
        let properties = compute_plan_properties(Arc::clone(&schema));
        Self {
            batches,
            schema,
            properties,
        }
    }
}
2645
// Manual Debug: show the batch count and schema, not the batch contents.
impl fmt::Debug for InMemoryExec {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("InMemoryExec")
            .field("num_batches", &self.batches.len())
            .field("schema", &self.schema)
            .finish()
    }
}
2654
// EXPLAIN rendering: node name plus batch count, identical for every format.
impl DisplayAs for InMemoryExec {
    fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "InMemoryExec: batches={}", self.batches.len())
    }
}
2660
impl ExecutionPlan for InMemoryExec {
    fn name(&self) -> &str {
        "InMemoryExec"
    }
    fn as_any(&self) -> &dyn Any {
        self
    }
    fn schema(&self) -> SchemaRef {
        Arc::clone(&self.schema)
    }
    fn properties(&self) -> &PlanProperties {
        &self.properties
    }
    // Leaf node: no children.
    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
        vec![]
    }
    fn with_new_children(
        self: Arc<Self>,
        _children: Vec<Arc<dyn ExecutionPlan>>,
    ) -> DFResult<Arc<dyn ExecutionPlan>> {
        Ok(self)
    }
    fn execute(
        &self,
        _partition: usize,
        _context: Arc<TaskContext>,
    ) -> DFResult<SendableRecordBatchStream> {
        // Stream a clone of the owned batches; every execution replays them.
        Ok(Box::pin(MemoryStream::try_new(
            self.batches.clone(),
            Arc::clone(&self.schema),
            None,
        )?))
    }
}
2695
2696pub(crate) async fn apply_post_fixpoint_chain(
2702 facts: Vec<RecordBatch>,
2703 rule: &FixpointRulePlan,
2704 task_ctx: &Arc<TaskContext>,
2705 strict_probability_domain: bool,
2706 probability_epsilon: f64,
2707) -> DFResult<Vec<RecordBatch>> {
2708 if !rule.has_fold && !rule.has_best_by && !rule.has_priority {
2709 return Ok(facts);
2710 }
2711
2712 let schema = facts
2717 .iter()
2718 .find(|b| b.num_rows() > 0)
2719 .map(|b| b.schema())
2720 .unwrap_or_else(|| Arc::clone(&rule.yield_schema));
2721 let input: Arc<dyn ExecutionPlan> = Arc::new(InMemoryExec::new(facts, schema));
2722
2723 let current: Arc<dyn ExecutionPlan> = if rule.has_priority {
2727 let priority_schema = input.schema();
2728 let priority_idx = priority_schema.index_of("__priority").map_err(|_| {
2729 datafusion::common::DataFusionError::Internal(
2730 "PRIORITY rule missing __priority column".to_string(),
2731 )
2732 })?;
2733 Arc::new(PriorityExec::new(
2734 input,
2735 rule.key_column_indices.clone(),
2736 priority_idx,
2737 ))
2738 } else {
2739 input
2740 };
2741
2742 let current: Arc<dyn ExecutionPlan> = if rule.has_fold && !rule.fold_bindings.is_empty() {
2744 Arc::new(FoldExec::new(
2745 current,
2746 rule.key_column_indices.clone(),
2747 rule.fold_bindings.clone(),
2748 strict_probability_domain,
2749 probability_epsilon,
2750 ))
2751 } else {
2752 current
2753 };
2754
2755 let current: Arc<dyn ExecutionPlan> = if rule.has_best_by && !rule.best_by_criteria.is_empty() {
2757 Arc::new(BestByExec::new(
2758 current,
2759 rule.key_column_indices.clone(),
2760 rule.best_by_criteria.clone(),
2761 rule.deterministic,
2762 ))
2763 } else {
2764 current
2765 };
2766
2767 collect_all_partitions(¤t, Arc::clone(task_ctx)).await
2768}
2769
/// Physical operator that drives the fixpoint loop over a set of derived
/// rules and streams out the converged facts.
pub struct FixpointExec {
    /// Rule plans evaluated by the loop.
    rules: Vec<FixpointRulePlan>,
    /// Hard cap on loop iterations.
    max_iterations: usize,
    /// Wall-clock budget for the loop.
    timeout: Duration,
    graph_ctx: Arc<GraphExecutionContext>,
    session_ctx: Arc<RwLock<datafusion::prelude::SessionContext>>,
    storage: Arc<StorageManager>,
    schema_info: Arc<UniSchema>,
    /// Query parameters forwarded into rule-body execution.
    params: HashMap<String, Value>,
    /// Registry of derived-table scan slots written between iterations.
    derived_scan_registry: Arc<DerivedScanRegistry>,
    output_schema: SchemaRef,
    properties: PlanProperties,
    metrics: ExecutionPlanMetricsSet,
    /// Byte budget for accumulated derived facts.
    max_derived_bytes: usize,
    /// Optional provenance store for proof tracking.
    derivation_tracker: Option<Arc<ProvenanceStore>>,
    /// Per-rule iteration counters, shared with callers via `iteration_counts()`.
    iteration_counts: Arc<StdRwLock<HashMap<String, usize>>>,
    strict_probability_domain: bool,
    probability_epsilon: f64,
    exact_probability: bool,
    max_bdd_variables: usize,
    /// Out-slot for runtime warnings raised during the loop.
    warnings_slot: Arc<StdRwLock<Vec<RuntimeWarning>>>,
    /// Out-slot mapping rule names to approximation notes.
    approximate_slot: Arc<StdRwLock<HashMap<String, Vec<String>>>>,
    top_k_proofs: usize,
}
2807
2808impl fmt::Debug for FixpointExec {
2809 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
2810 f.debug_struct("FixpointExec")
2811 .field("rules_count", &self.rules.len())
2812 .field("max_iterations", &self.max_iterations)
2813 .field("timeout", &self.timeout)
2814 .field("output_schema", &self.output_schema)
2815 .field("max_derived_bytes", &self.max_derived_bytes)
2816 .finish_non_exhaustive()
2817 }
2818}
2819
2820impl FixpointExec {
2821 #[expect(
2823 clippy::too_many_arguments,
2824 reason = "FixpointExec configuration needs all context"
2825 )]
2826 pub fn new(
2827 rules: Vec<FixpointRulePlan>,
2828 max_iterations: usize,
2829 timeout: Duration,
2830 graph_ctx: Arc<GraphExecutionContext>,
2831 session_ctx: Arc<RwLock<datafusion::prelude::SessionContext>>,
2832 storage: Arc<StorageManager>,
2833 schema_info: Arc<UniSchema>,
2834 params: HashMap<String, Value>,
2835 derived_scan_registry: Arc<DerivedScanRegistry>,
2836 output_schema: SchemaRef,
2837 max_derived_bytes: usize,
2838 derivation_tracker: Option<Arc<ProvenanceStore>>,
2839 iteration_counts: Arc<StdRwLock<HashMap<String, usize>>>,
2840 strict_probability_domain: bool,
2841 probability_epsilon: f64,
2842 exact_probability: bool,
2843 max_bdd_variables: usize,
2844 warnings_slot: Arc<StdRwLock<Vec<RuntimeWarning>>>,
2845 approximate_slot: Arc<StdRwLock<HashMap<String, Vec<String>>>>,
2846 top_k_proofs: usize,
2847 ) -> Self {
2848 let properties = compute_plan_properties(Arc::clone(&output_schema));
2849 Self {
2850 rules,
2851 max_iterations,
2852 timeout,
2853 graph_ctx,
2854 session_ctx,
2855 storage,
2856 schema_info,
2857 params,
2858 derived_scan_registry,
2859 output_schema,
2860 properties,
2861 metrics: ExecutionPlanMetricsSet::new(),
2862 max_derived_bytes,
2863 derivation_tracker,
2864 iteration_counts,
2865 strict_probability_domain,
2866 probability_epsilon,
2867 exact_probability,
2868 max_bdd_variables,
2869 warnings_slot,
2870 approximate_slot,
2871 top_k_proofs,
2872 }
2873 }
2874
2875 pub fn iteration_counts(&self) -> Arc<StdRwLock<HashMap<String, usize>>> {
2877 Arc::clone(&self.iteration_counts)
2878 }
2879}
2880
2881impl DisplayAs for FixpointExec {
2882 fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter<'_>) -> fmt::Result {
2883 write!(
2884 f,
2885 "FixpointExec: rules=[{}], max_iter={}, timeout={:?}",
2886 self.rules
2887 .iter()
2888 .map(|r| r.name.as_str())
2889 .collect::<Vec<_>>()
2890 .join(", "),
2891 self.max_iterations,
2892 self.timeout,
2893 )
2894 }
2895}
2896
impl ExecutionPlan for FixpointExec {
    fn name(&self) -> &str {
        "FixpointExec"
    }

    fn as_any(&self) -> &dyn Any {
        self
    }

    fn schema(&self) -> SchemaRef {
        Arc::clone(&self.output_schema)
    }

    fn properties(&self) -> &PlanProperties {
        &self.properties
    }

    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
        // Leaf node: rule bodies are planned and executed internally by
        // the fixpoint loop rather than exposed as child plans.
        vec![]
    }

    fn with_new_children(
        self: Arc<Self>,
        children: Vec<Arc<dyn ExecutionPlan>>,
    ) -> DFResult<Arc<dyn ExecutionPlan>> {
        // Reject attempts to attach children, since none are expected.
        if !children.is_empty() {
            return Err(datafusion::error::DataFusionError::Plan(
                "FixpointExec has no children".to_string(),
            ));
        }
        Ok(self)
    }

    // Builds the fixpoint-loop future and wraps it in a stream; the loop
    // itself does not start until the stream is first polled.
    fn execute(
        &self,
        partition: usize,
        _context: Arc<TaskContext>,
    ) -> DFResult<SendableRecordBatchStream> {
        let metrics = BaselineMetrics::new(&self.metrics, partition);

        // Rebuild each FixpointRulePlan field by field so the async loop
        // below owns an independent copy of the rule configuration.
        let rules = self
            .rules
            .iter()
            .map(|r| {
                FixpointRulePlan {
                    name: r.name.clone(),
                    clauses: r
                        .clauses
                        .iter()
                        .map(|c| FixpointClausePlan {
                            body_logical: c.body_logical.clone(),
                            is_ref_bindings: c.is_ref_bindings.clone(),
                            priority: c.priority,
                            along_bindings: c.along_bindings.clone(),
                        })
                        .collect(),
                    yield_schema: Arc::clone(&r.yield_schema),
                    key_column_indices: r.key_column_indices.clone(),
                    priority: r.priority,
                    has_fold: r.has_fold,
                    fold_bindings: r.fold_bindings.clone(),
                    has_best_by: r.has_best_by,
                    best_by_criteria: r.best_by_criteria.clone(),
                    has_priority: r.has_priority,
                    deterministic: r.deterministic,
                    prob_column_name: r.prob_column_name.clone(),
                }
            })
            .collect();

        // Capture owned copies of all configuration so the future below
        // is 'static and does not borrow `self`.
        let max_iterations = self.max_iterations;
        let timeout = self.timeout;
        let graph_ctx = Arc::clone(&self.graph_ctx);
        let session_ctx = Arc::clone(&self.session_ctx);
        let storage = Arc::clone(&self.storage);
        let schema_info = Arc::clone(&self.schema_info);
        let params = self.params.clone();
        let registry = Arc::clone(&self.derived_scan_registry);
        let output_schema = Arc::clone(&self.output_schema);
        let max_derived_bytes = self.max_derived_bytes;
        let derivation_tracker = self.derivation_tracker.clone();
        let iteration_counts = Arc::clone(&self.iteration_counts);
        let strict_probability_domain = self.strict_probability_domain;
        let probability_epsilon = self.probability_epsilon;
        let exact_probability = self.exact_probability;
        let max_bdd_variables = self.max_bdd_variables;
        let warnings_slot = Arc::clone(&self.warnings_slot);
        let approximate_slot = Arc::clone(&self.approximate_slot);
        let top_k_proofs = self.top_k_proofs;

        let fut = async move {
            run_fixpoint_loop(
                rules,
                max_iterations,
                timeout,
                graph_ctx,
                session_ctx,
                storage,
                schema_info,
                params,
                registry,
                output_schema,
                max_derived_bytes,
                derivation_tracker,
                iteration_counts,
                strict_probability_domain,
                probability_epsilon,
                exact_probability,
                max_bdd_variables,
                warnings_slot,
                approximate_slot,
                top_k_proofs,
            )
            .await
        };

        Ok(Box::pin(FixpointStream {
            state: FixpointStreamState::Running(Box::pin(fut)),
            schema: Arc::clone(&self.output_schema),
            metrics,
        }))
    }

    fn metrics(&self) -> Option<MetricsSet> {
        Some(self.metrics.clone_inner())
    }
}
3029
/// Internal state machine for [`FixpointStream`].
enum FixpointStreamState {
    /// The fixpoint loop is still running; holds the boxed future that
    /// will yield the final batch set.
    Running(Pin<Box<dyn std::future::Future<Output = DFResult<Vec<RecordBatch>>> + Send>>),
    /// The loop is done; batches are emitted one at a time, with the
    /// second field as the emission cursor.
    Emitting(Vec<RecordBatch>, usize),
    /// Terminal state: the stream yields `None` from here on.
    Done,
}
3042
/// Stream adapter that lazily drives the fixpoint-loop future and then
/// emits the resulting batches one by one.
struct FixpointStream {
    // Current phase: running the loop, emitting results, or finished.
    state: FixpointStreamState,
    // Output schema reported via `RecordBatchStream::schema`.
    schema: SchemaRef,
    // Output row counts are recorded here as batches are emitted.
    metrics: BaselineMetrics,
}
3048
3049impl Stream for FixpointStream {
3050 type Item = DFResult<RecordBatch>;
3051
3052 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
3053 let this = self.get_mut();
3054 loop {
3055 match &mut this.state {
3056 FixpointStreamState::Running(fut) => match fut.as_mut().poll(cx) {
3057 Poll::Ready(Ok(batches)) => {
3058 if batches.is_empty() {
3059 this.state = FixpointStreamState::Done;
3060 return Poll::Ready(None);
3061 }
3062 this.state = FixpointStreamState::Emitting(batches, 0);
3063 }
3065 Poll::Ready(Err(e)) => {
3066 this.state = FixpointStreamState::Done;
3067 return Poll::Ready(Some(Err(e)));
3068 }
3069 Poll::Pending => return Poll::Pending,
3070 },
3071 FixpointStreamState::Emitting(batches, idx) => {
3072 if *idx >= batches.len() {
3073 this.state = FixpointStreamState::Done;
3074 return Poll::Ready(None);
3075 }
3076 let batch = batches[*idx].clone();
3077 *idx += 1;
3078 this.metrics.record_output(batch.num_rows());
3079 return Poll::Ready(Some(Ok(batch)));
3080 }
3081 FixpointStreamState::Done => return Poll::Ready(None),
3082 }
3083 }
3084 }
3085}
3086
3087impl RecordBatchStream for FixpointStream {
3088 fn schema(&self) -> SchemaRef {
3089 Arc::clone(&self.schema)
3090 }
3091}
3092
3093#[cfg(test)]
3098mod tests {
3099 use super::*;
3100 use arrow_array::{Float64Array, Int64Array, StringArray};
3101 use arrow_schema::{DataType, Field, Schema};
3102
3103 fn test_schema() -> SchemaRef {
3104 Arc::new(Schema::new(vec![
3105 Field::new("name", DataType::Utf8, true),
3106 Field::new("value", DataType::Int64, true),
3107 ]))
3108 }
3109
3110 fn make_batch(names: &[&str], values: &[i64]) -> RecordBatch {
3111 RecordBatch::try_new(
3112 test_schema(),
3113 vec![
3114 Arc::new(StringArray::from(
3115 names.iter().map(|s| Some(*s)).collect::<Vec<_>>(),
3116 )),
3117 Arc::new(Int64Array::from(values.to_vec())),
3118 ],
3119 )
3120 .unwrap()
3121 }
3122
    // Merging a delta into an empty state must add every row and expose
    // it both as facts and as the current delta.
    #[tokio::test]
    async fn test_fixpoint_state_empty_facts_adds_all() {
        let schema = test_schema();
        let mut state = FixpointState::new("test".into(), schema, vec![0], 1_000_000, None, false);

        let batch = make_batch(&["a", "b", "c"], &[1, 2, 3]);
        let changed = state.merge_delta(vec![batch], None).await.unwrap();

        assert!(changed);
        assert_eq!(state.all_facts().len(), 1);
        assert_eq!(state.all_facts()[0].num_rows(), 3);
        assert_eq!(state.all_delta().len(), 1);
        assert_eq!(state.all_delta()[0].num_rows(), 3);
    }
3139
    // Re-merging identical rows must report no change and leave an
    // empty (or all-zero-row) delta.
    #[tokio::test]
    async fn test_fixpoint_state_exact_duplicates_excluded() {
        let schema = test_schema();
        let mut state = FixpointState::new("test".into(), schema, vec![0], 1_000_000, None, false);

        let batch1 = make_batch(&["a", "b"], &[1, 2]);
        state.merge_delta(vec![batch1], None).await.unwrap();

        let batch2 = make_batch(&["a", "b"], &[1, 2]);
        let changed = state.merge_delta(vec![batch2], None).await.unwrap();
        assert!(!changed);
        assert!(
            state.all_delta().is_empty() || state.all_delta().iter().all(|b| b.num_rows() == 0)
        );
    }
3156
    // A delta that overlaps existing facts contributes only its new rows.
    #[tokio::test]
    async fn test_fixpoint_state_partial_overlap() {
        let schema = test_schema();
        let mut state = FixpointState::new("test".into(), schema, vec![0], 1_000_000, None, false);

        let batch1 = make_batch(&["a", "b"], &[1, 2]);
        state.merge_delta(vec![batch1], None).await.unwrap();

        let batch2 = make_batch(&["a", "c"], &[1, 3]);
        let changed = state.merge_delta(vec![batch2], None).await.unwrap();
        assert!(changed);

        // Only ("c", 3) is new.
        let delta_rows: usize = state.all_delta().iter().map(|b| b.num_rows()).sum();
        assert_eq!(delta_rows, 1);

        let total_rows: usize = state.all_facts().iter().map(|b| b.num_rows()).sum();
        assert_eq!(total_rows, 3);
    }
3178
    // An empty delta after at least one merge marks the state converged.
    #[tokio::test]
    async fn test_fixpoint_state_convergence() {
        let schema = test_schema();
        let mut state = FixpointState::new("test".into(), schema, vec![0], 1_000_000, None, false);

        let batch = make_batch(&["a"], &[1]);
        state.merge_delta(vec![batch], None).await.unwrap();

        let changed = state.merge_delta(vec![], None).await.unwrap();
        assert!(!changed);
        assert!(state.is_converged());
    }
3192
    // Dedup state must persist across compute_delta calls: repeated rows
    // are dropped, genuinely new rows still pass through.
    #[test]
    fn test_row_dedup_persistent_across_calls() {
        let schema = test_schema();
        let mut rd = RowDedupState::try_new(&schema).expect("schema should be supported");

        let batch1 = make_batch(&["a", "b"], &[1, 2]);
        let delta1 = rd.compute_delta(&[batch1], &schema).unwrap();
        let rows1: usize = delta1.iter().map(|b| b.num_rows()).sum();
        assert_eq!(rows1, 2);

        let batch2 = make_batch(&["a", "b"], &[1, 2]);
        let delta2 = rd.compute_delta(&[batch2], &schema).unwrap();
        let rows2: usize = delta2.iter().map(|b| b.num_rows()).sum();
        assert_eq!(rows2, 0);

        let batch3 = make_batch(&["a", "c"], &[1, 3]);
        let delta3 = rd.compute_delta(&[batch3], &schema).unwrap();
        let rows3: usize = delta3.iter().map(|b| b.num_rows()).sum();
        assert_eq!(rows3, 1);
    }
3220
    // NULLs must compare equal to themselves for dedup purposes, while
    // rows differing in a non-NULL column stay distinct.
    #[test]
    fn test_row_dedup_null_handling() {
        use arrow_array::StringArray;
        use arrow_schema::{DataType, Field, Schema};

        let schema: SchemaRef = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Utf8, true),
            Field::new("b", DataType::Int64, true),
        ]));
        let mut rd = RowDedupState::try_new(&schema).expect("schema should be supported");

        let batch_nulls = RecordBatch::try_new(
            Arc::clone(&schema),
            vec![
                Arc::new(StringArray::from(vec![None::<&str>, None::<&str>])),
                Arc::new(arrow_array::Int64Array::from(vec![1i64, 1i64])),
            ],
        )
        .unwrap();
        let delta = rd.compute_delta(&[batch_nulls], &schema).unwrap();
        let rows: usize = delta.iter().map(|b| b.num_rows()).sum();
        assert_eq!(rows, 1, "two identical NULL rows should be deduped to one");

        let batch_diff = RecordBatch::try_new(
            Arc::clone(&schema),
            vec![
                Arc::new(StringArray::from(vec![None::<&str>])),
                Arc::new(arrow_array::Int64Array::from(vec![2i64])),
            ],
        )
        .unwrap();
        let delta2 = rd.compute_delta(&[batch_diff], &schema).unwrap();
        let rows2: usize = delta2.iter().map(|b| b.num_rows()).sum();
        assert_eq!(rows2, 1, "(NULL, 2) is distinct from (NULL, 1)");
    }
3258
    // Duplicates inside a single candidate batch must also be collapsed.
    #[test]
    fn test_row_dedup_within_candidate_dedup() {
        let schema = test_schema();
        let mut rd = RowDedupState::try_new(&schema).expect("schema should be supported");

        let batch = make_batch(&["a", "a", "b"], &[1, 1, 2]);
        let delta = rd.compute_delta(&[batch], &schema).unwrap();
        let rows: usize = delta.iter().map(|b| b.num_rows()).sum();
        assert_eq!(rows, 2, "within-batch dup should be collapsed: a:1, b:2");
    }
3271
    // Float rounding must map near-identical values to the same value so
    // they can be deduped.
    #[test]
    fn test_round_float_columns_near_duplicates() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("name", DataType::Utf8, true),
            Field::new("dist", DataType::Float64, true),
        ]));
        let batch = RecordBatch::try_new(
            Arc::clone(&schema),
            vec![
                Arc::new(StringArray::from(vec![Some("a"), Some("a")])),
                Arc::new(Float64Array::from(vec![1.0000000000001, 1.0000000000002])),
            ],
        )
        .unwrap();

        let rounded = round_float_columns(&[batch]);
        assert_eq!(rounded.len(), 1);
        let col = rounded[0]
            .column(1)
            .as_any()
            .downcast_ref::<Float64Array>()
            .unwrap();
        assert_eq!(col.value(0), col.value(1));
    }
3299
    // Data written through the registry must be readable through the
    // entry's shared data slot.
    #[test]
    fn test_registry_write_read_round_trip() {
        let schema = test_schema();
        let data = Arc::new(RwLock::new(Vec::new()));
        let mut reg = DerivedScanRegistry::new();
        reg.add(DerivedScanEntry {
            scan_index: 0,
            rule_name: "reachable".into(),
            is_self_ref: true,
            data: Arc::clone(&data),
            schema: Arc::clone(&schema),
        });

        let batch = make_batch(&["x"], &[42]);
        reg.write_data(0, vec![batch.clone()]);

        let entry = reg.get(0).unwrap();
        let guard = entry.data.read();
        assert_eq!(guard.len(), 1);
        assert_eq!(guard[0].num_rows(), 1);
    }
3323
    // entries_for_rule must filter by rule name across all registered scans.
    #[test]
    fn test_registry_entries_for_rule() {
        let schema = test_schema();
        let mut reg = DerivedScanRegistry::new();
        reg.add(DerivedScanEntry {
            scan_index: 0,
            rule_name: "r1".into(),
            is_self_ref: true,
            data: Arc::new(RwLock::new(Vec::new())),
            schema: Arc::clone(&schema),
        });
        reg.add(DerivedScanEntry {
            scan_index: 1,
            rule_name: "r2".into(),
            is_self_ref: false,
            data: Arc::new(RwLock::new(Vec::new())),
            schema: Arc::clone(&schema),
        });
        reg.add(DerivedScanEntry {
            scan_index: 2,
            rule_name: "r1".into(),
            is_self_ref: false,
            data: Arc::new(RwLock::new(Vec::new())),
            schema: Arc::clone(&schema),
        });

        assert_eq!(reg.entries_for_rule("r1").len(), 2);
        assert_eq!(reg.entries_for_rule("r2").len(), 1);
        assert_eq!(reg.entries_for_rule("r3").len(), 0);
    }
3354
3355 #[test]
3358 fn test_monotonic_agg_update_and_stability() {
3359 use crate::query::df_graph::locy_fold::FoldAggKind;
3360
3361 let bindings = vec![MonotonicFoldBinding {
3362 fold_name: "total".into(),
3363 kind: FoldAggKind::Sum,
3364 input_col_index: 1,
3365 }];
3366 let mut agg = MonotonicAggState::new(bindings);
3367
3368 let batch = make_batch(&["a"], &[10]);
3370 agg.snapshot();
3371 let changed = agg.update(&[0], &[batch], false).unwrap();
3372 assert!(changed);
3373 assert!(!agg.is_stable()); agg.snapshot();
3377 let changed = agg.update(&[0], &[], false).unwrap();
3378 assert!(!changed);
3379 assert!(agg.is_stable());
3380 }
3381
    // A 1-byte derived-facts budget must make the first merge fail with
    // a memory-limit error.
    #[tokio::test]
    async fn test_memory_limit_exceeded() {
        let schema = test_schema();
        let mut state = FixpointState::new("test".into(), schema, vec![0], 1, None, false);

        let batch = make_batch(&["a", "b", "c"], &[1, 2, 3]);
        let result = state.merge_delta(vec![batch], None).await;
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(err.contains("memory limit"), "Error was: {}", err);
    }
3396
    // A stream seeded in the Emitting state must yield each batch once
    // and then terminate.
    #[tokio::test]
    async fn test_fixpoint_stream_emitting() {
        use futures::StreamExt;

        let schema = test_schema();
        let batch1 = make_batch(&["a"], &[1]);
        let batch2 = make_batch(&["b"], &[2]);

        let metrics = ExecutionPlanMetricsSet::new();
        let baseline = BaselineMetrics::new(&metrics, 0);

        let mut stream = FixpointStream {
            state: FixpointStreamState::Emitting(vec![batch1, batch2], 0),
            schema,
            metrics: baseline,
        };

        let stream = Pin::new(&mut stream);
        let batches: Vec<RecordBatch> = stream.filter_map(|r| async { r.ok() }).collect().await;

        assert_eq!(batches.len(), 2);
        assert_eq!(batches[0].num_rows(), 1);
        assert_eq!(batches[1].num_rows(), 1);
    }
3423
3424 fn make_f64_batch(names: &[&str], values: &[f64]) -> RecordBatch {
3427 let schema = Arc::new(Schema::new(vec![
3428 Field::new("name", DataType::Utf8, true),
3429 Field::new("value", DataType::Float64, true),
3430 ]));
3431 RecordBatch::try_new(
3432 schema,
3433 vec![
3434 Arc::new(StringArray::from(
3435 names.iter().map(|s| Some(*s)).collect::<Vec<_>>(),
3436 )),
3437 Arc::new(Float64Array::from(values.to_vec())),
3438 ],
3439 )
3440 .unwrap()
3441 }
3442
3443 fn make_nor_binding() -> Vec<MonotonicFoldBinding> {
3444 use crate::query::df_graph::locy_fold::FoldAggKind;
3445 vec![MonotonicFoldBinding {
3446 fold_name: "prob".into(),
3447 kind: FoldAggKind::Nor,
3448 input_col_index: 1,
3449 }]
3450 }
3451
3452 fn make_prod_binding() -> Vec<MonotonicFoldBinding> {
3453 use crate::query::df_graph::locy_fold::FoldAggKind;
3454 vec![MonotonicFoldBinding {
3455 fold_name: "prob".into(),
3456 kind: FoldAggKind::Prod,
3457 input_col_index: 1,
3458 }]
3459 }
3460
3461 fn acc_key(name: &str) -> (Vec<ScalarKey>, String) {
3462 (vec![ScalarKey::Utf8(name.to_string())], "prob".to_string())
3463 }
3464
    // First NOR update: accumulator equals the single input probability.
    #[test]
    fn test_monotonic_nor_first_update() {
        let mut agg = MonotonicAggState::new(make_nor_binding());
        let batch = make_f64_batch(&["a"], &[0.3]);
        let changed = agg.update(&[0], &[batch], false).unwrap();
        assert!(changed);
        let val = agg.get_accumulator(&acc_key("a")).unwrap();
        assert!((val - 0.3).abs() < 1e-10, "expected 0.3, got {}", val);
    }
3474
    // NOR combine: 1 - (1-0.3)(1-0.5) = 0.65.
    #[test]
    fn test_monotonic_nor_two_updates() {
        let mut agg = MonotonicAggState::new(make_nor_binding());
        let batch1 = make_f64_batch(&["a"], &[0.3]);
        agg.update(&[0], &[batch1], false).unwrap();
        let batch2 = make_f64_batch(&["a"], &[0.5]);
        agg.update(&[0], &[batch2], false).unwrap();
        let val = agg.get_accumulator(&acc_key("a")).unwrap();
        assert!((val - 0.65).abs() < 1e-10, "expected 0.65, got {}", val);
    }
3486
    // First Prod update: accumulator equals the single input value.
    #[test]
    fn test_monotonic_prod_first_update() {
        let mut agg = MonotonicAggState::new(make_prod_binding());
        let batch = make_f64_batch(&["a"], &[0.6]);
        let changed = agg.update(&[0], &[batch], false).unwrap();
        assert!(changed);
        let val = agg.get_accumulator(&acc_key("a")).unwrap();
        assert!((val - 0.6).abs() < 1e-10, "expected 0.6, got {}", val);
    }
3496
    // Prod combine: 0.6 * 0.8 = 0.48.
    #[test]
    fn test_monotonic_prod_two_updates() {
        let mut agg = MonotonicAggState::new(make_prod_binding());
        let batch1 = make_f64_batch(&["a"], &[0.6]);
        agg.update(&[0], &[batch1], false).unwrap();
        let batch2 = make_f64_batch(&["a"], &[0.8]);
        agg.update(&[0], &[batch2], false).unwrap();
        let val = agg.get_accumulator(&acc_key("a")).unwrap();
        assert!((val - 0.48).abs() < 1e-10, "expected 0.48, got {}", val);
    }
3508
    // An empty update after a snapshot must leave the NOR state stable.
    #[test]
    fn test_monotonic_nor_stability() {
        let mut agg = MonotonicAggState::new(make_nor_binding());
        let batch = make_f64_batch(&["a"], &[0.3]);
        agg.update(&[0], &[batch], false).unwrap();
        agg.snapshot();
        let changed = agg.update(&[0], &[], false).unwrap();
        assert!(!changed);
        assert!(agg.is_stable());
    }
3519
    // An empty update after a snapshot must leave the Prod state stable.
    #[test]
    fn test_monotonic_prod_stability() {
        let mut agg = MonotonicAggState::new(make_prod_binding());
        let batch = make_f64_batch(&["a"], &[0.6]);
        agg.update(&[0], &[batch], false).unwrap();
        agg.snapshot();
        let changed = agg.update(&[0], &[], false).unwrap();
        assert!(!changed);
        assert!(agg.is_stable());
    }
3530
    // NOR accumulators are tracked per group key:
    // a: 1-(0.7*0.5)=0.65, b: 1-(0.5*0.8)=0.6.
    #[test]
    fn test_monotonic_nor_multi_group() {
        let mut agg = MonotonicAggState::new(make_nor_binding());
        let batch1 = make_f64_batch(&["a", "b"], &[0.3, 0.5]);
        agg.update(&[0], &[batch1], false).unwrap();
        let batch2 = make_f64_batch(&["a", "b"], &[0.5, 0.2]);
        agg.update(&[0], &[batch2], false).unwrap();

        let val_a = agg.get_accumulator(&acc_key("a")).unwrap();
        let val_b = agg.get_accumulator(&acc_key("b")).unwrap();
        assert!(
            (val_a - 0.65).abs() < 1e-10,
            "expected a=0.65, got {}",
            val_a
        );
        assert!((val_b - 0.6).abs() < 1e-10, "expected b=0.6, got {}", val_b);
    }
3549
    // Zero is absorbing for Prod: once the accumulator hits 0, further
    // updates neither change it nor break stability.
    #[test]
    fn test_monotonic_prod_zero_absorbing() {
        let mut agg = MonotonicAggState::new(make_prod_binding());
        let batch1 = make_f64_batch(&["a"], &[0.5]);
        agg.update(&[0], &[batch1], false).unwrap();
        let batch2 = make_f64_batch(&["a"], &[0.0]);
        agg.update(&[0], &[batch2], false).unwrap();

        let val = agg.get_accumulator(&acc_key("a")).unwrap();
        assert!((val - 0.0).abs() < 1e-10, "expected 0.0, got {}", val);

        agg.snapshot();
        let batch3 = make_f64_batch(&["a"], &[0.5]);
        let changed = agg.update(&[0], &[batch3], false).unwrap();
        assert!(!changed);
        assert!(agg.is_stable());
    }
3569
    // In non-strict mode, out-of-range NOR inputs are clamped to [0, 1].
    #[test]
    fn test_monotonic_nor_clamping() {
        let mut agg = MonotonicAggState::new(make_nor_binding());
        let batch = make_f64_batch(&["a"], &[1.5]);
        agg.update(&[0], &[batch], false).unwrap();
        let val = agg.get_accumulator(&acc_key("a")).unwrap();
        assert!((val - 1.0).abs() < 1e-10, "expected 1.0, got {}", val);
    }
3579
    // Probability 1.0 is absorbing for NOR: the accumulator saturates at 1.
    #[test]
    fn test_monotonic_nor_absorbing() {
        let mut agg = MonotonicAggState::new(make_nor_binding());
        let batch1 = make_f64_batch(&["a"], &[0.3]);
        agg.update(&[0], &[batch1], false).unwrap();
        let batch2 = make_f64_batch(&["a"], &[1.0]);
        agg.update(&[0], &[batch2], false).unwrap();
        let val = agg.get_accumulator(&acc_key("a")).unwrap();
        assert!((val - 1.0).abs() < 1e-10, "expected 1.0, got {}", val);
    }
3591
    // In strict mode, out-of-range NOR inputs must error rather than clamp.
    #[test]
    fn test_monotonic_agg_strict_nor_rejects() {
        let mut agg = MonotonicAggState::new(make_nor_binding());
        let batch = make_f64_batch(&["a"], &[1.5]);
        let result = agg.update(&[0], &[batch], true);
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(
            err.contains("strict_probability_domain"),
            "Expected strict error, got: {}",
            err
        );
    }
3607
    // In strict mode, out-of-range Prod inputs must error rather than clamp.
    #[test]
    fn test_monotonic_agg_strict_prod_rejects() {
        let mut agg = MonotonicAggState::new(make_prod_binding());
        let batch = make_f64_batch(&["a"], &[2.0]);
        let result = agg.update(&[0], &[batch], true);
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(
            err.contains("strict_probability_domain"),
            "Expected strict error, got: {}",
            err
        );
    }
3621
    // Strict mode still accepts inputs inside [0, 1].
    #[test]
    fn test_monotonic_agg_strict_accepts_valid() {
        let mut agg = MonotonicAggState::new(make_nor_binding());
        let batch = make_f64_batch(&["a"], &[0.5]);
        let result = agg.update(&[0], &[batch], true);
        assert!(result.is_ok());
        let val = agg.get_accumulator(&acc_key("a")).unwrap();
        assert!((val - 0.5).abs() < 1e-10, "expected 0.5, got {}", val);
    }
3631
3632 fn make_vid_prob_batch(vids: &[u64], probs: &[f64]) -> RecordBatch {
3635 use arrow_array::UInt64Array;
3636 let schema = Arc::new(Schema::new(vec![
3637 Field::new("vid", DataType::UInt64, true),
3638 Field::new("prob", DataType::Float64, true),
3639 ]));
3640 RecordBatch::try_new(
3641 schema,
3642 vec![
3643 Arc::new(UInt64Array::from(vids.to_vec())),
3644 Arc::new(Float64Array::from(probs.to_vec())),
3645 ],
3646 )
3647 .unwrap()
3648 }
3649
    // Matched rows get complement 1 - neg_prob (1 - 0.7 = 0.3); rows with
    // no negative match keep complement 1.0.
    #[test]
    fn test_prob_complement_basic() {
        let body = make_vid_prob_batch(&[1, 2], &[0.9, 0.8]);
        let neg = make_vid_prob_batch(&[1], &[0.7]);
        let join_cols = vec![("vid".to_string(), "vid".to_string())];
        let result = apply_prob_complement_composite(
            vec![body],
            &[neg],
            &join_cols,
            "prob",
            "__complement_0",
        )
        .unwrap();
        assert_eq!(result.len(), 1);
        let batch = &result[0];
        let complement = batch
            .column_by_name("__complement_0")
            .unwrap()
            .as_any()
            .downcast_ref::<Float64Array>()
            .unwrap();
        assert!(
            (complement.value(0) - 0.3).abs() < 1e-10,
            "expected 0.3, got {}",
            complement.value(0)
        );
        assert!(
            (complement.value(1) - 1.0).abs() < 1e-10,
            "expected 1.0, got {}",
            complement.value(1)
        );
    }
3685
    // Multiple negative matches combine noisy-or style:
    // (1 - 0.3) * (1 - 0.5) = 0.35.
    #[test]
    fn test_prob_complement_noisy_or_duplicates() {
        let body = make_vid_prob_batch(&[1], &[0.9]);
        let neg = make_vid_prob_batch(&[1, 1], &[0.3, 0.5]);
        let join_cols = vec![("vid".to_string(), "vid".to_string())];
        let result = apply_prob_complement_composite(
            vec![body],
            &[neg],
            &join_cols,
            "prob",
            "__complement_0",
        )
        .unwrap();
        let batch = &result[0];
        let complement = batch
            .column_by_name("__complement_0")
            .unwrap()
            .as_any()
            .downcast_ref::<Float64Array>()
            .unwrap();
        assert!(
            (complement.value(0) - 0.35).abs() < 1e-10,
            "expected 0.35, got {}",
            complement.value(0)
        );
    }
3715
    // With no negative facts at all, every complement is 1.0.
    #[test]
    fn test_prob_complement_empty_neg() {
        let body = make_vid_prob_batch(&[1, 2], &[0.5, 0.6]);
        let join_cols = vec![("vid".to_string(), "vid".to_string())];
        let result =
            apply_prob_complement_composite(vec![body], &[], &join_cols, "prob", "__complement_0")
                .unwrap();
        let batch = &result[0];
        let complement = batch
            .column_by_name("__complement_0")
            .unwrap()
            .as_any()
            .downcast_ref::<Float64Array>()
            .unwrap();
        for i in 0..2 {
            assert!(
                (complement.value(i) - 1.0).abs() < 1e-10,
                "row {}: expected 1.0, got {}",
                i,
                complement.value(i)
            );
        }
    }
3740
    // Anti-join removes body rows whose vid appears in the negative side.
    #[test]
    fn test_anti_join_basic() {
        use arrow_array::UInt64Array;
        let body = make_vid_prob_batch(&[1, 2, 3], &[0.5, 0.6, 0.7]);
        let neg = make_vid_prob_batch(&[2], &[0.0]);
        let join_cols = vec![("vid".to_string(), "vid".to_string())];
        let result = apply_anti_join_composite(vec![body], &[neg], &join_cols).unwrap();
        assert_eq!(result.len(), 1);
        let batch = &result[0];
        assert_eq!(batch.num_rows(), 2);
        let vids = batch
            .column_by_name("vid")
            .unwrap()
            .as_any()
            .downcast_ref::<UInt64Array>()
            .unwrap();
        assert_eq!(vids.value(0), 1);
        assert_eq!(vids.value(1), 3);
    }
3761
    // An empty negative side keeps every body row.
    #[test]
    fn test_anti_join_empty_neg() {
        let body = make_vid_prob_batch(&[1, 2, 3], &[0.5, 0.6, 0.7]);
        let join_cols = vec![("vid".to_string(), "vid".to_string())];
        let result = apply_anti_join_composite(vec![body], &[], &join_cols).unwrap();
        assert_eq!(result.len(), 1);
        assert_eq!(result[0].num_rows(), 3);
    }
3771
    // If every body vid matches the negative side, the result is empty.
    #[test]
    fn test_anti_join_all_excluded() {
        let body = make_vid_prob_batch(&[1, 2], &[0.5, 0.6]);
        let neg = make_vid_prob_batch(&[1, 2], &[0.0, 0.0]);
        let join_cols = vec![("vid".to_string(), "vid".to_string())];
        let result = apply_anti_join_composite(vec![body], &[neg], &join_cols).unwrap();
        let total: usize = result.iter().map(|b| b.num_rows()).sum();
        assert_eq!(total, 0);
    }
3782
    // prob is multiplied by the complement column (0.8 * 0.5 = 0.4) and
    // the factor column is dropped from the output.
    #[test]
    fn test_multiply_prob_single_complement() {
        let body = make_vid_prob_batch(&[1], &[0.8]);
        let complement_arr = Float64Array::from(vec![0.5]);
        let mut cols: Vec<arrow_array::ArrayRef> = body.columns().to_vec();
        cols.push(Arc::new(complement_arr));
        let mut fields: Vec<Arc<Field>> = body.schema().fields().iter().cloned().collect();
        fields.push(Arc::new(Field::new(
            "__complement_0",
            DataType::Float64,
            true,
        )));
        let schema = Arc::new(Schema::new(fields));
        let batch = RecordBatch::try_new(schema, cols).unwrap();

        let result =
            multiply_prob_factors(vec![batch], Some("prob"), &["__complement_0".to_string()])
                .unwrap();
        assert_eq!(result.len(), 1);
        let out = &result[0];
        assert!(out.column_by_name("__complement_0").is_none());
        let prob = out
            .column_by_name("prob")
            .unwrap()
            .as_any()
            .downcast_ref::<Float64Array>()
            .unwrap();
        assert!(
            (prob.value(0) - 0.4).abs() < 1e-10,
            "expected 0.4, got {}",
            prob.value(0)
        );
    }
3819
3820 #[test]
3821 fn test_multiply_prob_multiple_complements() {
3822 let body = make_vid_prob_batch(&[1], &[0.8]);
3824 let c1 = Float64Array::from(vec![0.5]);
3825 let c2 = Float64Array::from(vec![0.6]);
3826 let mut cols: Vec<arrow_array::ArrayRef> = body.columns().to_vec();
3827 cols.push(Arc::new(c1));
3828 cols.push(Arc::new(c2));
3829 let mut fields: Vec<Arc<Field>> = body.schema().fields().iter().cloned().collect();
3830 fields.push(Arc::new(Field::new("__c1", DataType::Float64, true)));
3831 fields.push(Arc::new(Field::new("__c2", DataType::Float64, true)));
3832 let schema = Arc::new(Schema::new(fields));
3833 let batch = RecordBatch::try_new(schema, cols).unwrap();
3834
3835 let result = multiply_prob_factors(
3836 vec![batch],
3837 Some("prob"),
3838 &["__c1".to_string(), "__c2".to_string()],
3839 )
3840 .unwrap();
3841 let out = &result[0];
3842 assert!(out.column_by_name("__c1").is_none());
3843 assert!(out.column_by_name("__c2").is_none());
3844 let prob = out
3845 .column_by_name("prob")
3846 .unwrap()
3847 .as_any()
3848 .downcast_ref::<Float64Array>()
3849 .unwrap();
3850 assert!(
3851 (prob.value(0) - 0.24).abs() < 1e-10,
3852 "expected 0.24, got {}",
3853 prob.value(0)
3854 );
3855 }
3856
3857 #[test]
3858 fn test_multiply_prob_no_prob_column() {
3859 use arrow_array::UInt64Array;
3861 let schema = Arc::new(Schema::new(vec![
3862 Field::new("vid", DataType::UInt64, true),
3863 Field::new("__c1", DataType::Float64, true),
3864 ]));
3865 let batch = RecordBatch::try_new(
3866 schema,
3867 vec![
3868 Arc::new(UInt64Array::from(vec![1u64])),
3869 Arc::new(Float64Array::from(vec![0.7])),
3870 ],
3871 )
3872 .unwrap();
3873
3874 let result = multiply_prob_factors(vec![batch], None, &["__c1".to_string()]).unwrap();
3875 let out = &result[0];
3876 assert!(out.column_by_name("__c1").is_none());
3878 assert_eq!(out.num_columns(), 1);
3880 }
3881}