1use crate::query::df_graph::GraphExecutionContext;
10use crate::query::df_graph::common::{
11 ScalarKey, arrow_err, collect_all_partitions, compute_plan_properties, execute_subplan,
12 execute_subplan_collecting, extract_scalar_key,
13};
14use crate::query::df_graph::locy_best_by::{BestByExec, SortCriterion};
15use crate::query::df_graph::locy_errors::LocyRuntimeError;
16use crate::query::df_graph::locy_explain::{
17 ProofTerm, ProvenanceAnnotation, ProvenanceStore, compute_proof_probability,
18};
19use crate::query::df_graph::locy_fold::{FoldBinding, FoldExec};
20use crate::query::df_graph::locy_priority::PriorityExec;
21use crate::query::df_graph::locy_profile::LocyProfileCollector;
22use crate::query::df_graph::locy_program::interruption;
23use crate::query::executor::core::OperatorStats;
24use crate::query::planner::LogicalPlan;
25use arrow_array::RecordBatch;
26use arrow_row::{RowConverter, SortField};
27use arrow_schema::SchemaRef;
28use datafusion::common::JoinType;
29use datafusion::common::Result as DFResult;
30use datafusion::execution::{RecordBatchStream, SendableRecordBatchStream, TaskContext};
31use datafusion::physical_plan::joins::{HashJoinExec, PartitionMode};
32use datafusion::physical_plan::memory::MemoryStream;
33use datafusion::physical_plan::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet};
34use datafusion::physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties};
35use futures::Stream;
36use parking_lot::RwLock;
37use std::any::Any;
38use std::collections::{HashMap, HashSet};
39use std::fmt;
40use std::pin::Pin;
41use std::sync::{Arc, RwLock as StdRwLock};
42use std::task::{Context, Poll};
43use std::time::{Duration, Instant};
44use uni_common::Value;
45use uni_common::core::schema::Schema as UniSchema;
46use uni_cypher::ast::Expr;
47use uni_locy::{
48 ClassifierRegistry, ModelInvocation, ModelInvocationCache, RuntimeWarning, RuntimeWarningCode,
49 SemiringKind,
50};
51use uni_store::storage::manager::StorageManager;
52
53#[derive(Debug)]
63pub struct DerivedScanEntry {
64 pub scan_index: usize,
66 pub rule_name: String,
68 pub is_self_ref: bool,
70 pub data: Arc<RwLock<Vec<RecordBatch>>>,
72 pub schema: SchemaRef,
74}
75
76#[derive(Debug, Default)]
83pub struct DerivedScanRegistry {
84 entries: Vec<DerivedScanEntry>,
85}
86
87impl DerivedScanRegistry {
88 pub fn new() -> Self {
90 Self::default()
91 }
92
93 pub fn add(&mut self, entry: DerivedScanEntry) {
95 self.entries.push(entry);
96 }
97
98 pub fn get(&self, scan_index: usize) -> Option<&DerivedScanEntry> {
100 self.entries.iter().find(|e| e.scan_index == scan_index)
101 }
102
103 pub fn write_data(&self, scan_index: usize, batches: Vec<RecordBatch>) {
105 if let Some(entry) = self.get(scan_index) {
106 let mut guard = entry.data.write();
107 *guard = batches;
108 }
109 }
110
111 pub fn entries_for_rule(&self, rule_name: &str) -> Vec<&DerivedScanEntry> {
113 self.entries
114 .iter()
115 .filter(|e| e.rule_name == rule_name)
116 .collect()
117 }
118}
119
120#[derive(Debug, Clone)]
131pub struct MonotonicFoldBinding {
132 pub fold_name: String,
133 pub aggregate: std::sync::Arc<dyn uni_plugin::traits::locy::LocyAggregate>,
134 pub input_col_index: usize,
135 pub input_col_name: Option<String>,
137}
138
139#[derive(Debug)]
145pub struct MonotonicAggState {
146 accumulators: HashMap<(Vec<ScalarKey>, String), f64>,
148 prev_snapshot: HashMap<(Vec<ScalarKey>, String), f64>,
150 bindings: Vec<MonotonicFoldBinding>,
152}
153
154impl MonotonicAggState {
155 pub fn new(bindings: Vec<MonotonicFoldBinding>) -> Self {
157 Self {
158 accumulators: HashMap::new(),
159 prev_snapshot: HashMap::new(),
160 bindings,
161 }
162 }
163
164 pub fn update(
190 &mut self,
191 key_indices: &[usize],
192 delta_batches: &[RecordBatch],
193 strict: bool,
194 semiring_kind: SemiringKind,
195 ) -> DFResult<bool> {
196 let mut changed = false;
197 for batch in delta_batches {
198 for row_idx in 0..batch.num_rows() {
199 let group_key = extract_scalar_key(batch, key_indices, row_idx);
200 for binding in &self.bindings {
201 let idx = binding
202 .input_col_name
203 .as_ref()
204 .and_then(|name| batch.schema().index_of(name).ok())
205 .unwrap_or(binding.input_col_index);
206 if idx >= batch.num_columns() {
207 continue;
208 }
209 let col = batch.column(idx);
210 let val = extract_f64(col.as_ref(), row_idx);
211 if let Some(val) = val {
212 let map_key = (group_key.clone(), binding.fold_name.clone());
213 let initial = binding.aggregate.initial_accum_f64().unwrap_or(0.0);
214 let entry = self.accumulators.entry(map_key).or_insert(initial);
215 let old = *entry;
216 if matches!(semiring_kind, SemiringKind::MaxMinProb)
227 && binding.aggregate.is_probability_aggregate()
228 {
229 use uni_locy::LocySemiring;
230 let sr = uni_locy::MaxMinProb;
231 let is_nor = binding.aggregate.is_noisy_or();
232 let label = if is_nor { "MNOR" } else { "MPROD" };
233 if strict && !(0.0..=1.0).contains(&val) {
234 return Err(datafusion::error::DataFusionError::Execution(
235 format!(
236 "strict_probability_domain: {label} input {val} is outside [0, 1]"
237 ),
238 ));
239 }
240 if !strict && !(0.0..=1.0).contains(&val) {
241 tracing::warn!(
242 "{label} input {val} outside [0,1], clamped to {}",
243 val.clamp(0.0, 1.0)
244 );
245 }
246 let p = val.clamp(0.0, 1.0);
247 *entry = if is_nor {
249 sr.plus(entry, &p)
250 } else {
251 sr.times(entry, &p)
252 };
253 if (*entry - old).abs() > f64::EPSILON {
254 changed = true;
255 }
256 continue;
257 }
258 match binding.aggregate.update_step(*entry, val, strict) {
259 Ok(new_val) => {
260 *entry = new_val;
261 if (*entry - old).abs() > f64::EPSILON {
262 changed = true;
263 }
264 }
265 Err(e) if e.code == uni_plugin::FnError::CODE_UNKNOWN_FUNCTION => {
266 }
270 Err(e) => {
271 return Err(datafusion::error::DataFusionError::Execution(
274 e.message,
275 ));
276 }
277 }
278 }
279 }
280 }
281 }
282 Ok(changed)
283 }
284
285 pub fn snapshot(&mut self) {
287 self.prev_snapshot = self.accumulators.clone();
288 }
289
290 pub fn is_stable(&self) -> bool {
292 if self.accumulators.len() != self.prev_snapshot.len() {
293 return false;
294 }
295 for (key, val) in &self.accumulators {
296 match self.prev_snapshot.get(key) {
297 Some(prev) if (*val - *prev).abs() <= f64::EPSILON => {}
298 _ => return false,
299 }
300 }
301 true
302 }
303
304 #[cfg(test)]
306 pub(crate) fn get_accumulator(&self, key: &(Vec<ScalarKey>, String)) -> Option<f64> {
307 self.accumulators.get(key).copied()
308 }
309}
310
311fn extract_f64(col: &dyn arrow_array::Array, row_idx: usize) -> Option<f64> {
313 if col.is_null(row_idx) {
314 return None;
315 }
316 if let Some(arr) = col.as_any().downcast_ref::<arrow_array::Float64Array>() {
317 Some(arr.value(row_idx))
318 } else {
319 col.as_any()
320 .downcast_ref::<arrow_array::Int64Array>()
321 .map(|arr| arr.value(row_idx) as f64)
322 }
323}
324
325struct RowDedupState {
335 converter: RowConverter,
336 seen: HashSet<Box<[u8]>>,
337}
338
339impl RowDedupState {
340 fn try_new(schema: &SchemaRef) -> Option<Self> {
345 let fields: Vec<SortField> = schema
346 .fields()
347 .iter()
348 .map(|f| SortField::new(f.data_type().clone()))
349 .collect();
350 match RowConverter::new(fields) {
351 Ok(converter) => Some(Self {
352 converter,
353 seen: HashSet::new(),
354 }),
355 Err(e) => {
356 tracing::warn!(
357 "RowDedupState: RowConverter unsupported for schema, falling back to legacy dedup: {}",
358 e
359 );
360 None
361 }
362 }
363 }
364
365 fn ingest_existing(&mut self, facts: &[RecordBatch], _schema: &SchemaRef) {
370 self.seen.clear();
371 for batch in facts {
372 if batch.num_rows() == 0 {
373 continue;
374 }
375 let arrays: Vec<_> = batch.columns().to_vec();
376 if let Ok(rows) = self.converter.convert_columns(&arrays) {
377 for row_idx in 0..batch.num_rows() {
378 let row_bytes: Box<[u8]> = rows.row(row_idx).data().into();
379 self.seen.insert(row_bytes);
380 }
381 }
382 }
383 }
384
385 fn compute_delta(
391 &mut self,
392 candidates: &[RecordBatch],
393 schema: &SchemaRef,
394 ) -> DFResult<Vec<RecordBatch>> {
395 let mut delta_batches = Vec::new();
396 for batch in candidates {
397 if batch.num_rows() == 0 {
398 continue;
399 }
400
401 let arrays: Vec<_> = batch.columns().to_vec();
403 let rows = self.converter.convert_columns(&arrays).map_err(arrow_err)?;
404
405 let mut keep = Vec::with_capacity(batch.num_rows());
407 for row_idx in 0..batch.num_rows() {
408 let row_bytes: Box<[u8]> = rows.row(row_idx).data().into();
409 keep.push(self.seen.insert(row_bytes));
410 }
411
412 let keep_mask = arrow_array::BooleanArray::from(keep);
413 let new_cols = batch
414 .columns()
415 .iter()
416 .map(|col| {
417 arrow::compute::filter(col.as_ref(), &keep_mask).map_err(|e| {
418 datafusion::error::DataFusionError::ArrowError(Box::new(e), None)
419 })
420 })
421 .collect::<DFResult<Vec<_>>>()?;
422
423 if new_cols.first().is_some_and(|c| !c.is_empty()) {
424 let filtered = RecordBatch::try_new(Arc::clone(schema), new_cols).map_err(|e| {
425 datafusion::error::DataFusionError::ArrowError(Box::new(e), None)
426 })?;
427 delta_batches.push(filtered);
428 }
429 }
430 Ok(delta_batches)
431 }
432}
433
434pub struct FixpointState {
444 rule_name: String,
445 facts: Vec<RecordBatch>,
446 delta: Vec<RecordBatch>,
447 schema: SchemaRef,
448 key_column_indices: Vec<usize>,
449 key_column_names: Vec<String>,
451 all_column_indices: Vec<usize>,
453 facts_bytes: usize,
455 max_derived_bytes: usize,
457 monotonic_agg: Option<MonotonicAggState>,
459 row_dedup: Option<RowDedupState>,
461 strict_probability_domain: bool,
463 semiring_kind: SemiringKind,
465}
466
467impl FixpointState {
468 pub fn new(
473 rule_name: String,
474 schema: SchemaRef,
475 key_column_indices: Vec<usize>,
476 max_derived_bytes: usize,
477 monotonic_agg: Option<MonotonicAggState>,
478 strict_probability_domain: bool,
479 ) -> Self {
480 Self::new_with_semiring(
481 rule_name,
482 schema,
483 key_column_indices,
484 max_derived_bytes,
485 monotonic_agg,
486 strict_probability_domain,
487 SemiringKind::AddMultProb,
488 )
489 }
490
491 pub fn new_with_semiring(
492 rule_name: String,
493 schema: SchemaRef,
494 key_column_indices: Vec<usize>,
495 max_derived_bytes: usize,
496 monotonic_agg: Option<MonotonicAggState>,
497 strict_probability_domain: bool,
498 semiring_kind: SemiringKind,
499 ) -> Self {
500 let num_cols = schema.fields().len();
501 let row_dedup = RowDedupState::try_new(&schema);
502 let key_column_names: Vec<String> = key_column_indices
503 .iter()
504 .filter_map(|&i| schema.fields().get(i).map(|f| f.name().clone()))
505 .collect();
506 Self {
507 rule_name,
508 facts: Vec::new(),
509 delta: Vec::new(),
510 schema,
511 key_column_indices,
512 key_column_names,
513 all_column_indices: (0..num_cols).collect(),
514 facts_bytes: 0,
515 max_derived_bytes,
516 monotonic_agg,
517 row_dedup,
518 strict_probability_domain,
519 semiring_kind,
520 }
521 }
522
523 fn reconcile_schema(&mut self, actual_schema: &SchemaRef) {
530 if self.schema.fields() != actual_schema.fields() {
531 tracing::debug!(
532 rule = %self.rule_name,
533 "Reconciling fixpoint schema from physical plan output",
534 );
535 self.schema = Arc::clone(actual_schema);
536 self.row_dedup = RowDedupState::try_new(&self.schema);
537 let new_indices: Vec<usize> = self
541 .key_column_names
542 .iter()
543 .filter_map(|name| actual_schema.index_of(name).ok())
544 .collect();
545 if new_indices.len() == self.key_column_names.len() {
546 self.key_column_indices = new_indices;
547 }
548 let num_cols = actual_schema.fields().len();
550 self.all_column_indices = (0..num_cols).collect();
551 }
552 }
553
554 pub async fn merge_delta(
558 &mut self,
559 candidates: Vec<RecordBatch>,
560 task_ctx: Option<Arc<TaskContext>>,
561 ) -> DFResult<bool> {
562 if candidates.is_empty() || candidates.iter().all(|b| b.num_rows() == 0) {
563 self.delta.clear();
564 return Ok(false);
565 }
566
567 if let Some(first) = candidates.iter().find(|b| b.num_rows() > 0) {
571 self.reconcile_schema(&first.schema());
572 }
573
574 let candidates = round_float_columns(&candidates);
576
577 let delta = self.compute_delta(&candidates, task_ctx.as_ref()).await?;
579
580 if delta.is_empty() || delta.iter().all(|b| b.num_rows() == 0) {
581 self.delta.clear();
582 if let Some(ref mut agg) = self.monotonic_agg {
584 agg.snapshot();
585 }
586 return Ok(false);
587 }
588
589 let delta_bytes: usize = delta.iter().map(batch_byte_size).sum();
591 if self.facts_bytes + delta_bytes > self.max_derived_bytes {
592 return Err(datafusion::error::DataFusionError::Execution(
593 LocyRuntimeError::MemoryLimitExceeded {
594 rule: self.rule_name.clone(),
595 bytes: self.facts_bytes + delta_bytes,
596 limit: self.max_derived_bytes,
597 }
598 .to_string(),
599 ));
600 }
601
602 if let Some(ref mut agg) = self.monotonic_agg {
604 agg.snapshot();
605 agg.update(
606 &self.key_column_indices,
607 &delta,
608 self.strict_probability_domain,
609 self.semiring_kind,
610 )?;
611 }
612
613 self.facts_bytes += delta_bytes;
615 self.facts.extend(delta.iter().cloned());
616 self.delta = delta;
617
618 Ok(true)
619 }
620
621 async fn compute_delta(
628 &mut self,
629 candidates: &[RecordBatch],
630 task_ctx: Option<&Arc<TaskContext>>,
631 ) -> DFResult<Vec<RecordBatch>> {
632 let total_existing: usize = self.facts.iter().map(|b| b.num_rows()).sum();
633 if total_existing >= DEDUP_ANTI_JOIN_THRESHOLD
634 && let Some(ctx) = task_ctx
635 {
636 return arrow_left_anti_dedup(candidates.to_vec(), &self.facts, &self.schema, ctx)
637 .await;
638 }
639 if let Some(ref mut rd) = self.row_dedup {
640 rd.compute_delta(candidates, &self.schema)
641 } else {
642 self.compute_delta_legacy(candidates)
643 }
644 }
645
646 fn compute_delta_legacy(&self, candidates: &[RecordBatch]) -> DFResult<Vec<RecordBatch>> {
650 let mut existing: HashSet<Vec<ScalarKey>> = HashSet::new();
652 for batch in &self.facts {
653 for row_idx in 0..batch.num_rows() {
654 let key = extract_scalar_key(batch, &self.all_column_indices, row_idx);
655 existing.insert(key);
656 }
657 }
658
659 let mut delta_batches = Vec::new();
660 for batch in candidates {
661 if batch.num_rows() == 0 {
662 continue;
663 }
664 let mut keep = Vec::with_capacity(batch.num_rows());
666 for row_idx in 0..batch.num_rows() {
667 let key = extract_scalar_key(batch, &self.all_column_indices, row_idx);
668 keep.push(!existing.contains(&key));
669 }
670
671 for (row_idx, kept) in keep.iter_mut().enumerate() {
673 if *kept {
674 let key = extract_scalar_key(batch, &self.all_column_indices, row_idx);
675 if !existing.insert(key) {
676 *kept = false;
677 }
678 }
679 }
680
681 let keep_mask = arrow_array::BooleanArray::from(keep);
682 let new_rows = batch
683 .columns()
684 .iter()
685 .map(|col| {
686 arrow::compute::filter(col.as_ref(), &keep_mask).map_err(|e| {
687 datafusion::error::DataFusionError::ArrowError(Box::new(e), None)
688 })
689 })
690 .collect::<DFResult<Vec<_>>>()?;
691
692 if new_rows.first().is_some_and(|c| !c.is_empty()) {
693 let filtered =
694 RecordBatch::try_new(Arc::clone(&self.schema), new_rows).map_err(|e| {
695 datafusion::error::DataFusionError::ArrowError(Box::new(e), None)
696 })?;
697 delta_batches.push(filtered);
698 }
699 }
700
701 Ok(delta_batches)
702 }
703
704 pub fn is_converged(&self) -> bool {
706 let delta_empty = self.delta.is_empty() || self.delta.iter().all(|b| b.num_rows() == 0);
707 let agg_stable = self.monotonic_agg.as_ref().is_none_or(|a| a.is_stable());
708 delta_empty && agg_stable
709 }
710
711 pub fn all_facts(&self) -> &[RecordBatch] {
713 &self.facts
714 }
715
716 pub fn all_delta(&self) -> &[RecordBatch] {
718 &self.delta
719 }
720
721 pub fn into_facts(self) -> Vec<RecordBatch> {
723 self.facts
724 }
725
726 pub fn merge_best_by(
737 &mut self,
738 candidates: Vec<RecordBatch>,
739 sort_criteria: &[SortCriterion],
740 ) -> DFResult<bool> {
741 if candidates.is_empty() || candidates.iter().all(|b| b.num_rows() == 0) {
742 self.delta.clear();
743 return Ok(false);
744 }
745
746 if let Some(first) = candidates.iter().find(|b| b.num_rows() > 0) {
748 self.reconcile_schema(&first.schema());
749 }
750
751 let candidates = round_float_columns(&candidates);
753
754 let old_best: HashMap<Vec<ScalarKey>, Vec<ScalarKey>> =
756 self.build_key_criteria_map(sort_criteria);
757
758 let mut all_batches = self.facts.clone();
760 all_batches.extend(candidates);
761 let all_batches: Vec<_> = all_batches
762 .into_iter()
763 .filter(|b| b.num_rows() > 0)
764 .collect();
765 if all_batches.is_empty() {
766 self.delta.clear();
767 return Ok(false);
768 }
769
770 let combined = arrow::compute::concat_batches(&self.schema, &all_batches)
771 .map_err(|e| datafusion::error::DataFusionError::ArrowError(Box::new(e), None))?;
772
773 if combined.num_rows() == 0 {
774 self.delta.clear();
775 return Ok(false);
776 }
777
778 let mut sort_columns = Vec::new();
781 for &ki in &self.key_column_indices {
782 if ki >= combined.num_columns() {
783 continue;
784 }
785 sort_columns.push(arrow::compute::SortColumn {
786 values: Arc::clone(combined.column(ki)),
787 options: Some(arrow::compute::SortOptions {
788 descending: false,
789 nulls_first: false,
790 }),
791 });
792 }
793 for criterion in sort_criteria {
794 if criterion.col_index >= combined.num_columns() {
795 continue;
796 }
797 sort_columns.push(arrow::compute::SortColumn {
798 values: Arc::clone(combined.column(criterion.col_index)),
799 options: Some(arrow::compute::SortOptions {
800 descending: !criterion.ascending,
801 nulls_first: criterion.nulls_first,
802 }),
803 });
804 }
805
806 let sorted_indices =
807 arrow::compute::lexsort_to_indices(&sort_columns, None).map_err(arrow_err)?;
808 let sorted_columns: Vec<_> = combined
809 .columns()
810 .iter()
811 .map(|col| arrow::compute::take(col.as_ref(), &sorted_indices, None))
812 .collect::<Result<Vec<_>, _>>()
813 .map_err(|e| datafusion::error::DataFusionError::ArrowError(Box::new(e), None))?;
814 let sorted = RecordBatch::try_new(Arc::clone(&self.schema), sorted_columns)
815 .map_err(|e| datafusion::error::DataFusionError::ArrowError(Box::new(e), None))?;
816
817 let mut keep_indices: Vec<u32> = Vec::new();
819 let mut prev_key: Option<Vec<ScalarKey>> = None;
820 for row_idx in 0..sorted.num_rows() {
821 let key = extract_scalar_key(&sorted, &self.key_column_indices, row_idx);
822 let is_new_group = match &prev_key {
823 None => true,
824 Some(prev) => *prev != key,
825 };
826 if is_new_group {
827 keep_indices.push(row_idx as u32);
828 prev_key = Some(key);
829 }
830 }
831
832 let keep_array = arrow_array::UInt32Array::from(keep_indices);
833 let output_columns: Vec<_> = sorted
834 .columns()
835 .iter()
836 .map(|col| arrow::compute::take(col.as_ref(), &keep_array, None))
837 .collect::<Result<Vec<_>, _>>()
838 .map_err(|e| datafusion::error::DataFusionError::ArrowError(Box::new(e), None))?;
839 let pruned = RecordBatch::try_new(Arc::clone(&self.schema), output_columns)
840 .map_err(|e| datafusion::error::DataFusionError::ArrowError(Box::new(e), None))?;
841
842 let new_best: HashMap<Vec<ScalarKey>, Vec<ScalarKey>> = {
844 let mut map = HashMap::new();
845 for row_idx in 0..pruned.num_rows() {
846 let key = extract_scalar_key(&pruned, &self.key_column_indices, row_idx);
847 let criteria: Vec<ScalarKey> = sort_criteria
848 .iter()
849 .flat_map(|c| extract_scalar_key(&pruned, &[c.col_index], row_idx))
850 .collect();
851 map.insert(key, criteria);
852 }
853 map
854 };
855 let changed = old_best != new_best;
856
857 tracing::debug!(
858 rule = %self.rule_name,
859 old_keys = old_best.len(),
860 new_keys = new_best.len(),
861 changed = changed,
862 "BEST BY merge"
863 );
864
865 self.facts_bytes = batch_byte_size(&pruned);
867 self.facts = vec![pruned];
868 if changed {
869 self.delta = self.facts.clone();
872 } else {
873 self.delta.clear();
874 }
875
876 self.row_dedup = RowDedupState::try_new(&self.schema);
878 if let Some(ref mut rd) = self.row_dedup {
879 rd.ingest_existing(&self.facts, &self.schema);
880 }
881
882 Ok(changed)
883 }
884
885 fn build_key_criteria_map(
887 &self,
888 sort_criteria: &[SortCriterion],
889 ) -> HashMap<Vec<ScalarKey>, Vec<ScalarKey>> {
890 let mut map = HashMap::new();
891 for batch in &self.facts {
892 for row_idx in 0..batch.num_rows() {
893 let key = extract_scalar_key(batch, &self.key_column_indices, row_idx);
894 let criteria: Vec<ScalarKey> = sort_criteria
895 .iter()
896 .flat_map(|c| extract_scalar_key(batch, &[c.col_index], row_idx))
897 .collect();
898 map.insert(key, criteria);
899 }
900 }
901 map
902 }
903}
904
905fn batch_byte_size(batch: &RecordBatch) -> usize {
907 batch
908 .columns()
909 .iter()
910 .map(|col| col.get_buffer_memory_size())
911 .sum()
912}
913
914fn round_float_columns(batches: &[RecordBatch]) -> Vec<RecordBatch> {
920 batches
921 .iter()
922 .map(|batch| {
923 let schema = batch.schema();
924 let has_float = schema
925 .fields()
926 .iter()
927 .any(|f| *f.data_type() == arrow_schema::DataType::Float64);
928 if !has_float {
929 return batch.clone();
930 }
931
932 let columns: Vec<arrow_array::ArrayRef> = batch
933 .columns()
934 .iter()
935 .enumerate()
936 .map(|(i, col)| {
937 if *schema.field(i).data_type() == arrow_schema::DataType::Float64 {
938 let arr = col
939 .as_any()
940 .downcast_ref::<arrow_array::Float64Array>()
941 .unwrap();
942 let rounded: arrow_array::Float64Array = arr
943 .iter()
944 .map(|v| v.map(|f| (f * 1e12).round() / 1e12))
945 .collect();
946 Arc::new(rounded) as arrow_array::ArrayRef
947 } else {
948 Arc::clone(col)
949 }
950 })
951 .collect();
952
953 RecordBatch::try_new(schema, columns).unwrap_or_else(|_| batch.clone())
954 })
955 .collect()
956}
957
/// Row count at which `FixpointState::compute_delta` switches strategy: once
/// the existing facts hold at least this many rows and a task context is
/// available, deduplication uses `arrow_left_anti_dedup` instead of the
/// in-memory row set.
const DEDUP_ANTI_JOIN_THRESHOLD: usize = 300;
968
969fn dedup_batches_all_columns(
984 batches: Vec<RecordBatch>,
985 schema: &SchemaRef,
986) -> DFResult<Vec<RecordBatch>> {
987 let fields: Vec<SortField> = schema
988 .fields()
989 .iter()
990 .map(|f| SortField::new(f.data_type().clone()))
991 .collect();
992 let Ok(converter) = RowConverter::new(fields) else {
996 return Ok(batches);
997 };
998 let mut seen: HashSet<Box<[u8]>> = HashSet::new();
999 let mut out = Vec::with_capacity(batches.len());
1000 for batch in batches {
1001 if batch.num_rows() == 0 {
1002 continue;
1003 }
1004 let rows = converter
1005 .convert_columns(batch.columns())
1006 .map_err(arrow_err)?;
1007 let mut keep = Vec::with_capacity(batch.num_rows());
1008 for row_idx in 0..batch.num_rows() {
1009 let row_bytes: Box<[u8]> = rows.row(row_idx).data().into();
1010 keep.push(seen.insert(row_bytes));
1011 }
1012 let keep_mask = arrow_array::BooleanArray::from(keep);
1013 let cols = batch
1014 .columns()
1015 .iter()
1016 .map(|c| arrow::compute::filter(c.as_ref(), &keep_mask).map_err(arrow_err))
1017 .collect::<DFResult<Vec<_>>>()?;
1018 if cols.first().is_some_and(|c| !c.is_empty()) {
1019 out.push(RecordBatch::try_new(Arc::clone(schema), cols).map_err(arrow_err)?);
1020 }
1021 }
1022 Ok(out)
1023}
1024
1025async fn arrow_left_anti_dedup(
1026 candidates: Vec<RecordBatch>,
1027 existing: &[RecordBatch],
1028 schema: &SchemaRef,
1029 task_ctx: &Arc<TaskContext>,
1030) -> DFResult<Vec<RecordBatch>> {
1031 if existing.is_empty() || existing.iter().all(|b| b.num_rows() == 0) {
1032 return dedup_batches_all_columns(candidates, schema);
1035 }
1036
1037 let left: Arc<dyn ExecutionPlan> = Arc::new(InMemoryExec::new(candidates, Arc::clone(schema)));
1038 let right: Arc<dyn ExecutionPlan> =
1039 Arc::new(InMemoryExec::new(existing.to_vec(), Arc::clone(schema)));
1040
1041 let on: Vec<(
1042 Arc<dyn datafusion::physical_plan::PhysicalExpr>,
1043 Arc<dyn datafusion::physical_plan::PhysicalExpr>,
1044 )> = schema
1045 .fields()
1046 .iter()
1047 .enumerate()
1048 .map(|(i, field)| {
1049 let l: Arc<dyn datafusion::physical_plan::PhysicalExpr> = Arc::new(
1050 datafusion::physical_plan::expressions::Column::new(field.name(), i),
1051 );
1052 let r: Arc<dyn datafusion::physical_plan::PhysicalExpr> = Arc::new(
1053 datafusion::physical_plan::expressions::Column::new(field.name(), i),
1054 );
1055 (l, r)
1056 })
1057 .collect();
1058
1059 if on.is_empty() {
1060 return Ok(vec![]);
1061 }
1062
1063 let join = HashJoinExec::try_new(
1064 left,
1065 right,
1066 on,
1067 None,
1068 &JoinType::LeftAnti,
1069 None,
1070 PartitionMode::CollectLeft,
1071 datafusion::common::NullEquality::NullEqualsNull,
1072 false,
1077 )?;
1078
1079 let join_arc: Arc<dyn ExecutionPlan> = Arc::new(join);
1080 let anti = collect_all_partitions(&join_arc, task_ctx.clone()).await?;
1083 dedup_batches_all_columns(anti, schema)
1084}
1085
/// Describes one reference from a clause body to a derived rule.
#[derive(Debug, Clone)]
pub struct IsRefBinding {
    /// Scan index used to look up the referenced entry in the
    /// `DerivedScanRegistry`.
    pub derived_scan_index: usize,
    /// Name of the referenced rule; also used to build the
    /// `__prob_complement_<rule_name>` column name.
    pub rule_name: String,
    /// Self-reference flag.
    /// NOTE(review): presumably true when the reference targets the enclosing
    /// rule — confirm against the planner.
    pub is_self_ref: bool,
    /// When true (and `anti_join_cols` is non-empty) the fixpoint loop removes
    /// or down-weights rows matching the referenced rule's facts.
    pub negated: bool,
    /// Column-name pairs handed to the anti-join / probability-complement
    /// helpers. NOTE(review): pair order (body column, target column) is
    /// assumed — verify against those helpers.
    pub anti_join_cols: Vec<(String, String)>,
    /// Whether the referenced rule has a probability column; together with
    /// `target_prob_col` it selects probability complement over a plain
    /// anti-join.
    pub target_has_prob: bool,
    /// Name of the referenced rule's probability column, when known.
    pub target_prob_col: Option<String>,
    /// Column-name pairs for provenance joins.
    /// NOTE(review): not used in the visible code — confirm semantics.
    pub provenance_join_cols: Vec<(String, String)>,
}
1117
1118#[derive(Debug)]
1120pub struct FixpointClausePlan {
1121 pub body_logical: LogicalPlan,
1123 pub is_ref_bindings: Vec<IsRefBinding>,
1125 pub priority: Option<i64>,
1127 pub along_bindings: Vec<String>,
1129 pub model_invocations: Vec<ModelInvocation>,
1133}
1134
1135#[derive(Debug)]
1137pub struct FixpointRulePlan {
1138 pub name: String,
1140 pub clauses: Vec<FixpointClausePlan>,
1142 pub yield_schema: SchemaRef,
1144 pub key_column_indices: Vec<usize>,
1146 pub priority: Option<i64>,
1148 pub has_fold: bool,
1150 pub fold_bindings: Vec<FoldBinding>,
1152 pub having: Vec<Expr>,
1154 pub has_best_by: bool,
1156 pub best_by_criteria: Vec<SortCriterion>,
1158 pub has_priority: bool,
1160 pub deterministic: bool,
1164 pub prob_column_name: Option<String>,
1166 pub non_linear: bool,
1173}
1174
1175#[expect(clippy::too_many_arguments, reason = "Fixpoint loop needs all context")]
1184async fn run_fixpoint_loop(
1185 rules: Vec<FixpointRulePlan>,
1186 max_iterations: usize,
1187 timeout: Duration,
1188 graph_ctx: Arc<GraphExecutionContext>,
1189 session_ctx: Arc<RwLock<datafusion::prelude::SessionContext>>,
1190 storage: Arc<StorageManager>,
1191 schema_info: Arc<UniSchema>,
1192 params: HashMap<String, Value>,
1193 registry: Arc<DerivedScanRegistry>,
1194 output_schema: SchemaRef,
1195 max_derived_bytes: usize,
1196 derivation_tracker: Option<Arc<ProvenanceStore>>,
1197 iteration_counts: Arc<StdRwLock<HashMap<String, usize>>>,
1198 strict_probability_domain: bool,
1199 probability_epsilon: f64,
1200 exact_probability: bool,
1201 max_bdd_variables: usize,
1202 warnings_slot: Arc<StdRwLock<Vec<RuntimeWarning>>>,
1203 approximate_slot: Arc<StdRwLock<HashMap<String, Vec<String>>>>,
1204 top_k_proofs: usize,
1205 timeout_flag: Arc<std::sync::atomic::AtomicU8>,
1206 semiring_kind: SemiringKind,
1207 classifier_registry: Arc<ClassifierRegistry>,
1208 classifier_cache: Option<Arc<ModelInvocationCache>>,
1209 classifier_provenance_store: Option<Arc<uni_locy::NeuralProvenanceStore>>,
1210 profile_collector: Option<Arc<LocyProfileCollector>>,
1211) -> DFResult<Vec<RecordBatch>> {
1212 let start = Instant::now();
1213 let task_ctx = session_ctx.read().task_ctx();
1214
1215 if semiring_kind == SemiringKind::MaxMinProb {
1220 let mut warnings = warnings_slot.write().unwrap_or_else(|e| e.into_inner());
1221 let mut already_warned: HashSet<String> = warnings
1222 .iter()
1223 .filter(|w| w.code == RuntimeWarningCode::FuzzyNotProbabilistic)
1224 .map(|w| w.rule_name.clone())
1225 .collect();
1226 for rule in &rules {
1227 if rule.prob_column_name.is_some() && !already_warned.contains(&rule.name) {
1228 warnings.push(RuntimeWarning {
1229 code: RuntimeWarningCode::FuzzyNotProbabilistic,
1230 message: format!(
1231 "rule '{}' carries a PROB column but is being evaluated under \
1232 the MaxMinProb (fuzzy / Viterbi) semiring; outputs are fuzzy \
1233 truth values, not probabilities",
1234 rule.name
1235 ),
1236 rule_name: rule.name.clone(),
1237 variable_count: None,
1238 key_group: None,
1239 });
1240 already_warned.insert(rule.name.clone());
1241 }
1242 }
1243 }
1244
1245 let mut states: Vec<FixpointState> = rules
1247 .iter()
1248 .map(|rule| {
1249 let monotonic_agg = if !rule.fold_bindings.is_empty() {
1250 let bindings: Vec<MonotonicFoldBinding> = rule
1251 .fold_bindings
1252 .iter()
1253 .map(|fb| MonotonicFoldBinding {
1254 fold_name: fb.output_name.clone(),
1255 aggregate: std::sync::Arc::clone(&fb.aggregate),
1256 input_col_index: fb.input_col_index,
1257 input_col_name: fb.input_col_name.clone(),
1258 })
1259 .collect();
1260 Some(MonotonicAggState::new(bindings))
1261 } else {
1262 None
1263 };
1264 FixpointState::new_with_semiring(
1265 rule.name.clone(),
1266 Arc::clone(&rule.yield_schema),
1267 rule.key_column_indices.clone(),
1268 max_derived_bytes,
1269 monotonic_agg,
1270 strict_probability_domain,
1271 semiring_kind,
1272 )
1273 })
1274 .collect();
1275
1276 let mut converged = false;
1278 let mut total_iters = 0usize;
1279 for iteration in 0..max_iterations {
1280 total_iters = iteration + 1;
1281 tracing::debug!("fixpoint iteration {}", iteration);
1282 let mut any_changed = false;
1283
1284 for rule_idx in 0..rules.len() {
1285 let rule = &rules[rule_idx];
1286
1287 update_derived_scan_handles(®istry, &states, rule_idx, &rules);
1289
1290 let mut all_candidates = Vec::new();
1292 let mut clause_candidates: Vec<Vec<RecordBatch>> = Vec::new();
1293 let rule_start = Instant::now();
1296 let mut iter_ops: Vec<OperatorStats> = Vec::new();
1297 for clause in &rule.clauses {
1298 let mut batches = if profile_collector.is_some() {
1304 let (b, ops) = execute_subplan_collecting(
1305 &clause.body_logical,
1306 ¶ms,
1307 &HashMap::new(),
1308 &graph_ctx,
1309 &session_ctx,
1310 &storage,
1311 &schema_info,
1312 None, )
1314 .await?;
1315 iter_ops.extend(ops);
1316 b
1317 } else {
1318 execute_subplan(
1319 &clause.body_logical,
1320 ¶ms,
1321 &HashMap::new(),
1322 &graph_ctx,
1323 &session_ctx,
1324 &storage,
1325 &schema_info,
1326 None, )
1328 .await?
1329 };
1330 for binding in &clause.is_ref_bindings {
1332 if binding.negated
1333 && !binding.anti_join_cols.is_empty()
1334 && let Some(entry) = registry.get(binding.derived_scan_index)
1335 {
1336 let neg_facts = entry.data.read().clone();
1337 if !neg_facts.is_empty() {
1338 if binding.target_has_prob && rule.prob_column_name.is_some() {
1339 let complement_col =
1341 format!("__prob_complement_{}", binding.rule_name);
1342 if let Some(prob_col) = &binding.target_prob_col {
1343 batches = apply_prob_complement_composite(
1344 batches,
1345 &neg_facts,
1346 &binding.anti_join_cols,
1347 prob_col,
1348 &complement_col,
1349 )?;
1350 } else {
1351 batches = apply_anti_join_composite(
1353 batches,
1354 &neg_facts,
1355 &binding.anti_join_cols,
1356 )?;
1357 }
1358 } else {
1359 batches = apply_anti_join_composite(
1361 batches,
1362 &neg_facts,
1363 &binding.anti_join_cols,
1364 )?;
1365 }
1366 }
1367 }
1368 }
1369 let complement_cols: Vec<String> = if !batches.is_empty() {
1371 batches[0]
1372 .schema()
1373 .fields()
1374 .iter()
1375 .filter(|f| f.name().starts_with("__prob_complement_"))
1376 .map(|f| f.name().clone())
1377 .collect()
1378 } else {
1379 vec![]
1380 };
1381 if !complement_cols.is_empty() {
1382 batches = multiply_prob_factors(
1383 batches,
1384 rule.prob_column_name.as_deref(),
1385 &complement_cols,
1386 )?;
1387 }
1388
1389 clause_candidates.push(batches.clone());
1390 all_candidates.extend(batches);
1391 }
1392
1393 let changed = if rule.has_best_by && !rule.best_by_criteria.is_empty() {
1397 states[rule_idx].merge_best_by(all_candidates, &rule.best_by_criteria)?
1398 } else {
1399 states[rule_idx]
1400 .merge_delta(all_candidates, Some(Arc::clone(&task_ctx)))
1401 .await?
1402 };
1403 if changed {
1404 any_changed = true;
1405 if let Some(ref tracker) = derivation_tracker {
1407 record_provenance(
1408 ProvenanceCtx {
1409 tracker,
1410 registry: ®istry,
1411 warnings_slot: &warnings_slot,
1412 },
1413 rule,
1414 &states[rule_idx],
1415 &clause_candidates,
1416 iteration,
1417 top_k_proofs,
1418 ClassifierRefs {
1419 registry: &classifier_registry,
1420 cache: classifier_cache.as_ref(),
1421 provenance_store: classifier_provenance_store.as_ref(),
1422 },
1423 )
1424 .await;
1425 }
1426 }
1427
1428 if let Some(ref collector) = profile_collector {
1431 let delta_facts: usize = states[rule_idx]
1432 .all_delta()
1433 .iter()
1434 .map(|b| b.num_rows())
1435 .sum();
1436 collector.record(
1437 &rule.name,
1438 iteration,
1439 delta_facts,
1440 rule_start.elapsed().as_secs_f64() * 1000.0,
1441 iter_ops,
1442 );
1443 }
1444 }
1445
1446 if !any_changed && states.iter().all(|s| s.is_converged()) {
1448 tracing::debug!("fixpoint converged after {} iterations", iteration + 1);
1449 converged = true;
1450 break;
1451 }
1452
1453 if start.elapsed() > timeout {
1455 tracing::warn!(
1456 "fixpoint timeout after {} iterations; returning partial results",
1457 iteration + 1,
1458 );
1459 interruption::set(&timeout_flag, interruption::TIMEOUT);
1460 break;
1461 }
1462 }
1463
1464 if let Ok(mut counts) = iteration_counts.write() {
1466 for rule in &rules {
1467 counts.insert(rule.name.clone(), total_iters);
1468 }
1469 }
1470
1471 if let Some(ref collector) = profile_collector {
1473 for (idx, rule) in rules.iter().enumerate() {
1474 let facts: usize = states[idx].all_facts().iter().map(|b| b.num_rows()).sum();
1475 collector.set_final_facts(&rule.name, facts);
1476 }
1477 }
1478
1479 if !converged && interruption::reason(&timeout_flag).is_none() {
1484 tracing::warn!(
1485 "fixpoint did not converge after {max_iterations} iterations; returning partial results",
1486 );
1487 interruption::set(&timeout_flag, interruption::ITERATION_LIMIT);
1488 }
1489
1490 let task_ctx = session_ctx.read().task_ctx();
1492 let mut all_output = Vec::new();
1493
1494 for (rule_idx, state) in states.into_iter().enumerate() {
1495 let rule = &rules[rule_idx];
1496 let mut facts = state.into_facts();
1497 if facts.is_empty() {
1498 continue;
1499 }
1500
1501 let shared_info = if semiring_kind == SemiringKind::MaxMinProb {
1519 None
1520 } else if let Some(ref tracker) = derivation_tracker {
1521 detect_shared_lineage(rule, &facts, tracker, &warnings_slot, semiring_kind)
1522 } else {
1523 None
1524 };
1525
1526 if exact_probability
1528 && let Some(ref info) = shared_info
1529 && let Some(ref tracker) = derivation_tracker
1530 {
1531 facts = apply_exact_wmc(
1532 facts,
1533 rule,
1534 info,
1535 tracker,
1536 max_bdd_variables,
1537 &warnings_slot,
1538 &approximate_slot,
1539 )?;
1540 }
1541
1542 let processed = apply_post_fixpoint_chain(
1543 facts,
1544 rule,
1545 &task_ctx,
1546 strict_probability_domain,
1547 probability_epsilon,
1548 semiring_kind,
1549 derivation_tracker.as_ref().map(Arc::clone),
1550 top_k_proofs,
1551 Some(Arc::clone(®istry)),
1552 )
1553 .await?;
1554 all_output.extend(processed);
1555 }
1556
1557 if all_output.is_empty() {
1559 all_output.push(RecordBatch::new_empty(output_schema));
1560 }
1561
1562 Ok(all_output)
1563}
1564
/// Borrowed bundle of the classifier-related handles that the provenance
/// recording helpers take as a single argument.
pub(crate) struct ClassifierRefs<'a> {
    /// Registry queried by model name when a model invocation has to be replayed.
    pub registry: &'a Arc<ClassifierRegistry>,
    /// Optional memo of raw classifier outputs, addressed by model name and input hash.
    pub cache: Option<&'a Arc<uni_locy::ModelInvocationCache>>,
    /// Optional store of already-recorded neural provenance; when present it is
    /// consulted before the registry/cache path.
    pub provenance_store: Option<&'a Arc<uni_locy::NeuralProvenanceStore>>,
}
1587
/// Borrowed handles that `record_provenance` needs to store derivations and
/// report warnings.
pub(crate) struct ProvenanceCtx<'a> {
    /// Store that receives one annotation per newly derived row.
    pub tracker: &'a Arc<ProvenanceStore>,
    /// Registry of derived-scan entries, read to find the source facts behind
    /// a clause's IS-ref bindings.
    pub registry: &'a Arc<DerivedScanRegistry>,
    /// Shared list that runtime warnings are appended to.
    pub warnings_slot: &'a Arc<StdRwLock<Vec<RuntimeWarning>>>,
}
1598
1599async fn record_provenance(
1600 prov: ProvenanceCtx<'_>,
1601 rule: &FixpointRulePlan,
1602 state: &FixpointState,
1603 clause_candidates: &[Vec<RecordBatch>],
1604 iteration: usize,
1605 top_k_proofs: usize,
1606 classifiers: ClassifierRefs<'_>,
1607) {
1608 let tracker = prov.tracker;
1609 let registry = prov.registry;
1610 let warnings_slot = prov.warnings_slot;
1611 let classifier_registry = classifiers.registry;
1612 let classifier_cache = classifiers.cache;
1613 let all_indices: Vec<usize> = (0..rule.yield_schema.fields().len()).collect();
1614
1615 let base_probs = if top_k_proofs > 0 {
1617 tracker.base_fact_probs()
1618 } else {
1619 HashMap::new()
1620 };
1621
1622 let mut topk_acc = TopKProofAccumulator::new();
1623
1624 for delta_batch in state.all_delta() {
1625 for row_idx in 0..delta_batch.num_rows() {
1626 let row_hash = format!(
1627 "{:?}",
1628 extract_scalar_key(delta_batch, &all_indices, row_idx)
1629 )
1630 .into_bytes();
1631 let fact_row = batch_row_to_value_map(delta_batch, row_idx);
1632 let clause_index =
1633 find_clause_for_row(delta_batch, row_idx, &all_indices, clause_candidates);
1634
1635 let support = collect_is_ref_inputs(rule, clause_index, delta_batch, row_idx, registry);
1636
1637 let proof_probability = if top_k_proofs > 0 {
1638 compute_proof_probability(&support, &base_probs)
1639 } else {
1640 None
1641 };
1642
1643 let entry = ProvenanceAnnotation {
1644 rule_name: rule.name.clone(),
1645 clause_index,
1646 support,
1647 along_values: {
1648 let along_names: Vec<String> = rule
1649 .clauses
1650 .get(clause_index)
1651 .map(|c| c.along_bindings.clone())
1652 .unwrap_or_default();
1653 along_names
1654 .iter()
1655 .filter_map(|name| fact_row.get(name).map(|v| (name.clone(), v.clone())))
1656 .collect()
1657 },
1658 iteration,
1659 fact_row: fact_row.clone(),
1660 proof_probability,
1661 neural_calls: collect_neural_calls_for_row(
1662 rule,
1663 clause_index,
1664 &fact_row,
1665 classifier_registry,
1666 classifier_cache,
1667 classifiers.provenance_store,
1668 )
1669 .await,
1670 };
1671 if top_k_proofs > 0 {
1672 topk_acc.accumulate(&entry, &row_hash);
1673 tracker.record_top_k(row_hash, entry, top_k_proofs);
1674 } else {
1675 tracker.record(row_hash, entry);
1676 }
1677 }
1678 }
1679
1680 topk_acc.emit_warning_if_any(rule, top_k_proofs, warnings_slot);
1681}
1682
/// Collects the proofs seen for each fact during one provenance-recording
/// pass so that a top-K pruning warning can be emitted afterwards.
struct TopKProofAccumulator {
    /// Proofs gathered so far, keyed by the fact's row hash.
    per_fact: HashMap<Vec<u8>, Vec<uni_locy::Proof>>,
    /// Maps a base fact id to the `BaseRv` assigned to it on first sight.
    base_rv_interner: HashMap<Vec<u8>, uni_locy::BaseRv>,
    /// Next unused `BaseRv` number; incremented each time a new base fact is interned.
    next_rv: u32,
}
1694
1695impl TopKProofAccumulator {
1696 fn new() -> Self {
1697 Self {
1698 per_fact: HashMap::new(),
1699 base_rv_interner: HashMap::new(),
1700 next_rv: 0,
1701 }
1702 }
1703
1704 fn accumulate(&mut self, entry: &ProvenanceAnnotation, row_hash: &[u8]) {
1705 let mut base_rvs = uni_locy::BaseRvSet::empty();
1706 for term in &entry.support {
1707 let rv = *self
1708 .base_rv_interner
1709 .entry(term.base_fact_id.clone())
1710 .or_insert_with(|| {
1711 let r = uni_locy::BaseRv(self.next_rv);
1712 self.next_rv += 1;
1713 r
1714 });
1715 base_rvs.insert(rv);
1716 }
1717 self.per_fact
1718 .entry(row_hash.to_vec())
1719 .or_default()
1720 .push(uni_locy::Proof {
1721 weight: entry.proof_probability.unwrap_or(0.0),
1722 base_rvs,
1723 neural_calls: Vec::new(),
1724 });
1725 }
1726
1727 fn emit_warning_if_any(
1728 &self,
1729 rule: &FixpointRulePlan,
1730 top_k_proofs: usize,
1731 warnings_slot: &Arc<StdRwLock<Vec<RuntimeWarning>>>,
1732 ) {
1733 if top_k_proofs == 0 || self.per_fact.is_empty() {
1734 return;
1735 }
1736 let crossed_facts = self
1737 .per_fact
1738 .values()
1739 .filter(|proofs| {
1740 let (_kept, notice) =
1741 uni_locy::merge_top_k_runtime(Vec::new(), (*proofs).clone(), top_k_proofs);
1742 notice == uni_locy::PruneNotice::CrossedDependency
1743 })
1744 .count();
1745 if crossed_facts == 0 {
1746 return;
1747 }
1748 let Ok(mut w) = warnings_slot.write() else {
1749 return;
1750 };
1751 let already = w.iter().any(|rw| {
1752 matches!(
1753 rw.code,
1754 uni_locy::types::RuntimeWarningCode::TopKPruningCrossedDependency
1755 ) && rw.rule_name == rule.name
1756 });
1757 if already {
1758 return;
1759 }
1760 w.push(RuntimeWarning {
1761 code: uni_locy::types::RuntimeWarningCode::TopKPruningCrossedDependency,
1762 rule_name: rule.name.clone(),
1763 message: format!(
1764 "rule '{}': top-K proof pruning (k={}) discarded {} fact(s) \
1765 whose dependencies overlap retained proofs. The retained \
1766 top-{} under-counts the true joint probability for those \
1767 facts (Scallop, Huang et al. 2021). Increase k to recover.",
1768 rule.name, top_k_proofs, crossed_facts, top_k_proofs
1769 ),
1770 variable_count: None,
1771 key_group: None,
1772 });
1773 }
1774}
1775
1776async fn collect_neural_calls_for_row(
1798 rule: &FixpointRulePlan,
1799 clause_index: usize,
1800 fact_row: &uni_locy::FactRow,
1801 classifier_registry: &Arc<ClassifierRegistry>,
1802 classifier_cache: Option<&Arc<uni_locy::ModelInvocationCache>>,
1803 provenance_store: Option<&Arc<uni_locy::NeuralProvenanceStore>>,
1804) -> Vec<uni_locy::NeuralProvenance> {
1805 let Some(clause) = rule.clauses.get(clause_index) else {
1806 return Vec::new();
1807 };
1808 if clause.model_invocations.is_empty() {
1809 return Vec::new();
1810 }
1811 let mut out = Vec::with_capacity(clause.model_invocations.len());
1812 for invocation in &clause.model_invocations {
1813 let mut features = std::collections::HashMap::new();
1824 for (binding_name, feat_expr) in invocation
1825 .feature_names
1826 .iter()
1827 .zip(invocation.feature_exprs.iter())
1828 {
1829 features.insert(
1830 binding_name.clone(),
1831 eval_feature_expr_against_fact_row(feat_expr, fact_row),
1832 );
1833 }
1834 let input = uni_locy::ClassifyInput { features };
1835 let input_hash = input.stable_hash();
1836
1837 if let Some(store) = provenance_store
1844 && let Some(record) = store.get(&invocation.model_name, input_hash)
1845 {
1846 out.push(uni_locy::NeuralProvenance {
1847 model_name: invocation.model_name.clone(),
1848 raw_probability: record.raw_probability,
1849 calibrated_probability: record.calibrated_probability,
1850 confidence_band: record.confidence_band,
1851 });
1852 continue;
1853 }
1854
1855 let Some(classifier) = classifier_registry.get(&invocation.model_name) else {
1860 continue;
1861 };
1862 let raw = if let Some(v) =
1863 classifier_cache.and_then(|c| c.get(&invocation.model_name, input_hash))
1864 {
1865 v
1866 } else {
1867 match classifier.classify(std::slice::from_ref(&input)).await {
1868 Ok(probs) => {
1869 let v = probs.first().copied().unwrap_or(0.0);
1870 if let Some(c) = classifier_cache {
1871 c.insert(&invocation.model_name, input_hash, v);
1872 }
1873 v
1874 }
1875 Err(_) => continue,
1876 }
1877 };
1878 let calibrator = classifier.get_calibrator();
1879 let calibrated_probability = calibrator.as_ref().map(|_| raw);
1880 let confidence_band = calibrator.as_ref().and_then(|c| c.confidence_band(raw));
1881 out.push(uni_locy::NeuralProvenance {
1882 model_name: invocation.model_name.clone(),
1883 raw_probability: raw,
1884 calibrated_probability,
1885 confidence_band,
1886 });
1887 }
1888 out
1889}
1890
1891fn eval_feature_expr_against_fact_row(
1899 expr: &uni_cypher::ast::Expr,
1900 fact_row: &uni_locy::FactRow,
1901) -> uni_locy::FeatureValue {
1902 use uni_cypher::ast::Expr;
1903 use uni_locy::FeatureValue;
1904 let value_to_feature = |v: Option<&uni_common::Value>| -> FeatureValue {
1905 match v {
1906 Some(uni_common::Value::Float(f)) => FeatureValue::Float(*f),
1907 Some(uni_common::Value::Int(i)) => FeatureValue::Int(*i),
1908 Some(uni_common::Value::Bool(b)) => FeatureValue::Bool(*b),
1909 Some(uni_common::Value::String(s)) => FeatureValue::String(s.clone()),
1910 Some(uni_common::Value::Node(n)) => {
1911 FeatureValue::Int(n.vid.as_u64() as i64)
1913 }
1914 _ => FeatureValue::Null,
1915 }
1916 };
1917 let resolve_value = |sub: &Expr| -> uni_common::Value {
1921 match sub {
1922 Expr::Variable(name) => fact_row
1923 .get(name)
1924 .cloned()
1925 .unwrap_or(uni_common::Value::Null),
1926 Expr::Property(boxed, prop) if matches!(boxed.as_ref(), Expr::Variable(_)) => {
1927 let Expr::Variable(v) = boxed.as_ref() else {
1928 unreachable!()
1929 };
1930 let key = format!("{}.{}", v, prop);
1931 if let Some(val) = fact_row.get(&key) {
1932 return val.clone();
1933 }
1934 if let Some(uni_common::Value::Node(n)) = fact_row.get(v) {
1935 return n
1936 .properties
1937 .get(prop)
1938 .cloned()
1939 .unwrap_or(uni_common::Value::Null);
1940 }
1941 uni_common::Value::Null
1942 }
1943 Expr::Literal(lit) => lit.to_value(),
1944 Expr::List(items) => {
1945 let mut out = Vec::with_capacity(items.len());
1946 for it in items {
1947 out.push(match it {
1948 Expr::Literal(lit) => lit.to_value(),
1949 _ => uni_common::Value::Null,
1950 });
1951 }
1952 uni_common::Value::List(out)
1953 }
1954 _ => uni_common::Value::Null,
1955 }
1956 };
1957
1958 match expr {
1959 Expr::Variable(name) => value_to_feature(fact_row.get(name)),
1960 Expr::Property(boxed, prop) => {
1961 if let Expr::Variable(v) = boxed.as_ref() {
1962 let key = format!("{}.{}", v, prop);
1964 if let Some(val) = fact_row.get(&key) {
1965 return value_to_feature(Some(val));
1966 }
1967 let hidden_key = format!("__feat_{}_{}", v, prop);
1977 if let Some(val) = fact_row.get(&hidden_key) {
1978 return value_to_feature(Some(val));
1979 }
1980 if let Some(uni_common::Value::Node(n)) = fact_row.get(v) {
1984 return value_to_feature(n.properties.get(prop));
1985 }
1986 }
1987 FeatureValue::Null
1988 }
1989 Expr::FunctionCall { name, args, .. } if name == "similar_to" && args.len() == 2 => {
1990 let lv = resolve_value(&args[0]);
1991 let rv = resolve_value(&args[1]);
1992 match crate::query::similar_to::eval_similar_to_pure(&lv, &rv) {
1993 Ok(uni_common::Value::Float(f)) => FeatureValue::Float(f),
1994 _ => FeatureValue::Null,
1995 }
1996 }
1997 Expr::FunctionCall { name, .. }
2012 if matches!(
2013 name.as_str(),
2014 "degree_centrality"
2015 | "pagerank_score"
2016 | "closeness_centrality"
2017 | "betweenness_centrality"
2018 | "eigenvector_centrality"
2019 | "harmonic_centrality"
2020 | "katz_centrality"
2021 | "avg_neighbor"
2022 | "max_neighbor"
2023 | "sum_neighbor"
2024 ) =>
2025 {
2026 FeatureValue::Null
2027 }
2028 _ => FeatureValue::Null,
2029 }
2030}
2031
2032fn collect_is_ref_inputs(
2033 rule: &FixpointRulePlan,
2034 clause_index: usize,
2035 delta_batch: &RecordBatch,
2036 row_idx: usize,
2037 registry: &Arc<DerivedScanRegistry>,
2038) -> Vec<ProofTerm> {
2039 let clause = match rule.clauses.get(clause_index) {
2040 Some(c) => c,
2041 None => return vec![],
2042 };
2043
2044 let mut inputs = Vec::new();
2045 let delta_schema = delta_batch.schema();
2046
2047 for binding in &clause.is_ref_bindings {
2048 if binding.negated {
2049 continue;
2050 }
2051 if binding.provenance_join_cols.is_empty() {
2052 continue;
2053 }
2054
2055 let body_values: Vec<(String, ScalarKey)> = binding
2057 .provenance_join_cols
2058 .iter()
2059 .filter_map(|(body_col, _derived_col)| {
2060 let col_idx = delta_schema
2061 .fields()
2062 .iter()
2063 .position(|f| f.name() == body_col)?;
2064 let key = extract_scalar_key(delta_batch, &[col_idx], row_idx);
2065 Some((body_col.clone(), key.into_iter().next()?))
2066 })
2067 .collect();
2068
2069 if body_values.len() != binding.provenance_join_cols.len() {
2070 continue;
2071 }
2072
2073 let entry = match registry.get(binding.derived_scan_index) {
2075 Some(e) => e,
2076 None => continue,
2077 };
2078 let source_batches = entry.data.read();
2079 let source_schema = &entry.schema;
2080
2081 for src_batch in source_batches.iter() {
2083 let all_src_indices: Vec<usize> = (0..src_batch.num_columns()).collect();
2084 for src_row in 0..src_batch.num_rows() {
2085 let matches = binding.provenance_join_cols.iter().enumerate().all(
2086 |(i, (_body_col, derived_col))| {
2087 let src_col_idx = source_schema
2088 .fields()
2089 .iter()
2090 .position(|f| f.name() == derived_col);
2091 match src_col_idx {
2092 Some(idx) => {
2093 let src_key = extract_scalar_key(src_batch, &[idx], src_row);
2094 src_key.first() == Some(&body_values[i].1)
2095 }
2096 None => false,
2097 }
2098 },
2099 );
2100 if matches {
2101 let fact_hash = format!(
2102 "{:?}",
2103 extract_scalar_key(src_batch, &all_src_indices, src_row)
2104 )
2105 .into_bytes();
2106 inputs.push(ProofTerm {
2107 source_rule: binding.rule_name.clone(),
2108 base_fact_id: fact_hash,
2109 });
2110 }
2111 }
2112 }
2113 }
2114
2115 inputs
2116}
2117
2118fn collect_is_ref_inputs_for_body_row(
2140 rule: &FixpointRulePlan,
2141 delta_batch: &RecordBatch,
2142 row_idx: usize,
2143 registry: &Arc<DerivedScanRegistry>,
2144) -> Vec<ProofTerm> {
2145 let mut combined: Vec<ProofTerm> = Vec::new();
2146 for clause_index in 0..rule.clauses.len() {
2147 let part = collect_is_ref_inputs(rule, clause_index, delta_batch, row_idx, registry);
2148 combined.extend(part);
2149 }
2150 combined
2151}
2152
/// One row of a KEY group that was found to share lineage with another row of
/// the same group.
#[expect(
    dead_code,
    reason = "Fields accessed via SharedLineageInfo in detect_shared_lineage"
)]
pub(crate) struct SharedGroupRow {
    /// Hash identifying the row (see `fact_hash_key`).
    pub fact_hash: Vec<u8>,
    /// Base fact ids the row's derivation bottoms out in, as computed by `compute_lineage`.
    pub lineage: HashSet<Vec<u8>>,
}
2179
/// Result of shared-lineage detection for one rule.
pub(crate) struct SharedLineageInfo {
    /// KEY groups flagged as sharing lineage, mapped to the rows of each group.
    pub shared_groups: HashMap<Vec<ScalarKey>, Vec<SharedGroupRow>>,
}
2185
2186pub(crate) fn fact_hash_key(batch: &RecordBatch, all_indices: &[usize], row_idx: usize) -> Vec<u8> {
2188 format!("{:?}", extract_scalar_key(batch, all_indices, row_idx)).into_bytes()
2189}
2190
/// Detects KEY groups whose rows do not have independent derivations, for
/// rules that fold with a probability aggregate.
///
/// Returns `None` when the rule has no probability-aggregate fold binding or
/// when no group is flagged; otherwise returns the flagged groups together
/// with each row's lineage.
///
/// Side effects on `warnings_slot` (each at most once per rule, and silently
/// skipped if the lock cannot be acquired):
/// - `CrossGroupCorrelationNotExact` when a direct support fact is used by
///   rows of more than one KEY group;
/// - `SharedProbabilisticDependency` when any group is flagged, unless
///   `semiring_kind` is `TopKProofs`.
fn detect_shared_lineage(
    rule: &FixpointRulePlan,
    pre_fold_facts: &[RecordBatch],
    tracker: &Arc<ProvenanceStore>,
    warnings_slot: &Arc<StdRwLock<Vec<RuntimeWarning>>>,
    semiring_kind: SemiringKind,
) -> Option<SharedLineageInfo> {
    use uni_locy::{RuntimeWarning, RuntimeWarningCode};

    // Only rules that aggregate probabilities are affected by shared lineage.
    let has_prob_fold = rule
        .fold_bindings
        .iter()
        .any(|fb| fb.aggregate.is_probability_aggregate());
    if !has_prob_fold {
        return None;
    }

    let key_indices = &rule.key_column_indices;
    let all_indices: Vec<usize> = (0..rule.yield_schema.fields().len()).collect();

    // Bucket every pre-fold row (by its fact hash) under its KEY-column value.
    let mut groups: HashMap<Vec<ScalarKey>, Vec<Vec<u8>>> = HashMap::new();
    for batch in pre_fold_facts {
        for row_idx in 0..batch.num_rows() {
            let key = extract_scalar_key(batch, key_indices, row_idx);
            let fact_hash = fact_hash_key(batch, &all_indices, row_idx);
            groups.entry(key).or_default().push(fact_hash);
        }
    }

    let mut shared_groups: HashMap<Vec<ScalarKey>, Vec<SharedGroupRow>> = HashMap::new();
    let mut any_shared = false;

    for (key, fact_hashes) in &groups {
        // A group with a single row has nothing to be correlated with.
        if fact_hashes.len() < 2 {
            continue;
        }

        // `has_inputs` becomes true once any row of the group has a recorded
        // annotation with non-empty support.
        let mut has_inputs = false;
        let mut per_row_bases: Vec<HashSet<Vec<u8>>> = Vec::new();
        for fh in fact_hashes {
            // Fresh visited set per row: each row's lineage is computed independently.
            let bases = compute_lineage(fh, tracker, &mut HashSet::new());
            if let Some(entry) = tracker.lookup(fh)
                && !entry.support.is_empty()
            {
                has_inputs = true;
            }
            per_row_bases.push(bases);
        }

        let shared_found = if has_inputs {
            // Shared iff any two rows of the group have overlapping lineage sets.
            let mut found = false;
            'outer: for i in 0..per_row_bases.len() {
                for j in (i + 1)..per_row_bases.len() {
                    if !per_row_bases[i].is_disjoint(&per_row_bases[j]) {
                        found = true;
                        break 'outer;
                    }
                }
            }
            found
        } else {
            // No support was recorded for any row: fall back to flagging the
            // group when a row's clause has at least one positive IS-ref binding.
            fact_hashes.iter().any(|fh| {
                tracker.lookup(fh).is_some_and(|entry| {
                    rule.clauses
                        .get(entry.clause_index)
                        .is_some_and(|clause| clause.is_ref_bindings.iter().any(|b| !b.negated))
                })
            })
        };

        if shared_found {
            any_shared = true;
            // `per_row_bases` was pushed in `fact_hashes` order, so zipping pairs
            // each hash with its own lineage.
            let rows: Vec<SharedGroupRow> = fact_hashes
                .iter()
                .zip(per_row_bases)
                .map(|(fh, bases)| SharedGroupRow {
                    fact_hash: fh.clone(),
                    lineage: bases,
                })
                .collect();
            shared_groups.insert(key.clone(), rows);
        }
    }

    {
        // Cross-group check: for every direct support fact, collect the set of
        // KEY groups whose rows use it.
        let mut input_to_groups: HashMap<Vec<u8>, HashSet<Vec<ScalarKey>>> = HashMap::new();
        for (key, fact_hashes) in &groups {
            for fh in fact_hashes {
                if let Some(entry) = tracker.lookup(fh) {
                    for input in &entry.support {
                        input_to_groups
                            .entry(input.base_fact_id.clone())
                            .or_default()
                            .insert(key.clone());
                    }
                }
            }
        }
        let has_cross_group = input_to_groups.values().any(|g| g.len() > 1);
        if has_cross_group && let Ok(mut warnings) = warnings_slot.write() {
            let already_warned = warnings.iter().any(|w| {
                w.code == RuntimeWarningCode::CrossGroupCorrelationNotExact
                    && w.rule_name == rule.name
            });
            if !already_warned {
                // One illustrative input: its first 8 bytes in hex plus the
                // sorted debug renderings of the groups sharing it. Which input
                // is picked depends on map iteration order.
                let example =
                    input_to_groups
                        .iter()
                        .find(|(_, g)| g.len() > 1)
                        .map(|(input, groups)| {
                            let short = input
                                .iter()
                                .take(8)
                                .map(|b| format!("{:02x}", b))
                                .collect::<String>();
                            let mut group_strs: Vec<String> =
                                groups.iter().map(|k| format!("{:?}", k)).collect();
                            group_strs.sort();
                            format!(
                                "input {} shared by groups [{}]",
                                short,
                                group_strs.join(", ")
                            )
                        });
                // Number of support facts that appear in more than one group.
                let shared_variable_count =
                    input_to_groups.values().filter(|g| g.len() > 1).count();
                warnings.push(RuntimeWarning {
                    code: RuntimeWarningCode::CrossGroupCorrelationNotExact,
                    message: format!(
                        "Rule '{}': {} IS-ref base fact(s) are shared across different \
                         KEY groups. BDD corrects per-group probabilities but cannot \
                         account for cross-group correlations.",
                        rule.name, shared_variable_count
                    ),
                    rule_name: rule.name.clone(),
                    variable_count: Some(shared_variable_count),
                    key_group: example,
                });
            }
        }
    }

    if any_shared {
        // The independence warning is not emitted under the top-K proofs semiring.
        let suppress_under_topk = matches!(semiring_kind, SemiringKind::TopKProofs { .. });
        if !suppress_under_topk && let Ok(mut warnings) = warnings_slot.write() {
            let already_warned = warnings.iter().any(|w| {
                w.code == RuntimeWarningCode::SharedProbabilisticDependency
                    && w.rule_name == rule.name
            });
            if !already_warned {
                warnings.push(RuntimeWarning {
                    code: RuntimeWarningCode::SharedProbabilisticDependency,
                    message: format!(
                        "Rule '{}' aggregates with MNOR/MPROD but some proof paths \
                         share intermediate facts, violating the independence assumption. \
                         Results may overestimate probability.",
                        rule.name
                    ),
                    rule_name: rule.name.clone(),
                    variable_count: None,
                    key_group: None,
                });
            }
        }
        Some(SharedLineageInfo { shared_groups })
    } else {
        None
    }
}
2395
2396#[allow(
2404 clippy::too_many_arguments,
2405 reason = "context bundle would be over-engineering for one call site"
2406)]
2407pub(crate) async fn record_and_detect_lineage_nonrecursive(
2408 rule: &FixpointRulePlan,
2409 tagged_clause_facts: &[(usize, Vec<RecordBatch>)],
2410 tracker: &Arc<ProvenanceStore>,
2411 warnings_slot: &Arc<StdRwLock<Vec<RuntimeWarning>>>,
2412 registry: &Arc<DerivedScanRegistry>,
2413 top_k_proofs: usize,
2414 classifiers: ClassifierRefs<'_>,
2415 semiring_kind: SemiringKind,
2416) -> Option<SharedLineageInfo> {
2417 let classifier_registry = classifiers.registry;
2418 let classifier_cache = classifiers.cache;
2419 let all_indices: Vec<usize> = (0..rule.yield_schema.fields().len()).collect();
2420
2421 let base_probs = if top_k_proofs > 0 {
2423 tracker.base_fact_probs()
2424 } else {
2425 HashMap::new()
2426 };
2427
2428 let mut topk_acc = TopKProofAccumulator::new();
2429
2430 for (clause_index, batches) in tagged_clause_facts {
2432 for batch in batches {
2433 for row_idx in 0..batch.num_rows() {
2434 let row_hash = fact_hash_key(batch, &all_indices, row_idx);
2435 let fact_row = batch_row_to_value_map(batch, row_idx);
2436
2437 let support = collect_is_ref_inputs(rule, *clause_index, batch, row_idx, registry);
2438
2439 let proof_probability = if top_k_proofs > 0 {
2440 compute_proof_probability(&support, &base_probs)
2441 } else {
2442 None
2443 };
2444
2445 let entry = ProvenanceAnnotation {
2446 rule_name: rule.name.clone(),
2447 clause_index: *clause_index,
2448 support,
2449 along_values: {
2450 let along_names: Vec<String> = rule
2451 .clauses
2452 .get(*clause_index)
2453 .map(|c| c.along_bindings.clone())
2454 .unwrap_or_default();
2455 along_names
2456 .iter()
2457 .filter_map(|name| {
2458 fact_row.get(name).map(|v| (name.clone(), v.clone()))
2459 })
2460 .collect()
2461 },
2462 iteration: 0,
2463 fact_row: fact_row.clone(),
2464 proof_probability,
2465 neural_calls: collect_neural_calls_for_row(
2466 rule,
2467 *clause_index,
2468 &fact_row,
2469 classifier_registry,
2470 classifier_cache,
2471 classifiers.provenance_store,
2472 )
2473 .await,
2474 };
2475 if top_k_proofs > 0 {
2476 topk_acc.accumulate(&entry, &row_hash);
2477 tracker.record_top_k(row_hash, entry, top_k_proofs);
2478 } else {
2479 tracker.record(row_hash, entry);
2480 }
2481 }
2482 }
2483 }
2484
2485 topk_acc.emit_warning_if_any(rule, top_k_proofs, warnings_slot);
2486
2487 let all_facts: Vec<RecordBatch> = tagged_clause_facts
2489 .iter()
2490 .flat_map(|(_, batches)| batches.iter().cloned())
2491 .collect();
2492 detect_shared_lineage(rule, &all_facts, tracker, warnings_slot, semiring_kind)
2493}
2494
/// Replaces the probability of each shared-lineage KEY group with the result
/// of a weighted model count over the group's base facts.
///
/// Rows whose key is not in `shared_info.shared_groups` pass through
/// untouched. For a shared group, either:
/// - the count is exact: only the group's first-seen row is kept, and its
///   probability column is overwritten with the computed probability; or
/// - the count was approximated: all rows of the group are kept
///   unchanged, a `BddLimitExceeded` warning is recorded (once per rule and
///   key group) and the key is appended to `approximate_slot` under the rule name.
///
/// Returns the input unchanged when the rule has no probability-aggregate
/// fold binding.
///
/// # Errors
/// Propagates Arrow errors from `take` and from rebuilding a `RecordBatch`.
pub(crate) fn apply_exact_wmc(
    pre_fold_facts: Vec<RecordBatch>,
    rule: &FixpointRulePlan,
    shared_info: &SharedLineageInfo,
    tracker: &Arc<ProvenanceStore>,
    max_bdd_variables: usize,
    warnings_slot: &Arc<StdRwLock<Vec<RuntimeWarning>>>,
    approximate_slot: &Arc<StdRwLock<HashMap<String, Vec<String>>>>,
) -> DFResult<Vec<RecordBatch>> {
    use crate::query::df_graph::locy_bdd::{SemiringOp, weighted_model_count};
    use uni_locy::{RuntimeWarning, RuntimeWarningCode};

    // The first probability-aggregate fold binding decides both the semiring
    // operation and which column carries the probability.
    let prob_fold = rule
        .fold_bindings
        .iter()
        .find(|fb| fb.aggregate.is_probability_aggregate());
    let prob_fold = match prob_fold {
        Some(f) => f,
        None => return Ok(pre_fold_facts),
    };
    let semiring_op = if prob_fold.aggregate.is_noisy_or() {
        SemiringOp::Disjunction
    } else {
        SemiringOp::Conjunction
    };
    let prob_col_idx = prob_fold.input_col_index;
    let prob_col_name = rule.yield_schema.field(prob_col_idx).name().clone();

    let key_indices = &rule.key_column_indices;
    let all_indices: Vec<usize> = (0..rule.yield_schema.fields().len()).collect();

    let shared_keys: HashSet<Vec<ScalarKey>> = shared_info.shared_groups.keys().cloned().collect();

    // Per-group working state; row positions are (batch index, row index).
    struct GroupAccum {
        // One lineage set per row of the group.
        base_facts: Vec<HashSet<Vec<u8>>>,
        // Probability of each base fact, as read from its tracked fact row.
        base_probs: HashMap<Vec<u8>, f64>,
        // First row encountered for the group; the only row kept on exact results.
        representative: (usize, usize),
        // Every row of the group; all kept when the result is approximated.
        row_locations: Vec<(usize, usize)>,
    }

    let mut group_accums: HashMap<Vec<ScalarKey>, GroupAccum> = HashMap::new();
    let mut non_shared_rows: Vec<(usize, usize)> = Vec::new();
    for (batch_idx, batch) in pre_fold_facts.iter().enumerate() {
        for row_idx in 0..batch.num_rows() {
            let key = extract_scalar_key(batch, key_indices, row_idx);
            if shared_keys.contains(&key) {
                let fact_hash = fact_hash_key(batch, &all_indices, row_idx);
                let bases = compute_lineage(&fact_hash, tracker, &mut HashSet::new());

                let accum = group_accums.entry(key).or_insert_with(|| GroupAccum {
                    base_facts: Vec::new(),
                    base_probs: HashMap::new(),
                    representative: (batch_idx, row_idx),
                    row_locations: Vec::new(),
                });

                // Record a probability for each base fact not seen yet in this
                // group; base facts without a tracked, numeric value in the
                // probability column are left out of `base_probs`.
                for bf in &bases {
                    if !accum.base_probs.contains_key(bf)
                        && let Some(entry) = tracker.lookup(bf)
                        && let Some(val) = entry.fact_row.get(&prob_col_name)
                        && let Some(p) = value_to_f64(val)
                    {
                        accum.base_probs.insert(bf.clone(), p);
                    }
                }

                accum.base_facts.push(bases);
                accum.row_locations.push((batch_idx, row_idx));
            } else {
                non_shared_rows.push((batch_idx, row_idx));
            }
        }
    }

    // Rows that survive into the output, and probability overrides for them.
    let mut keep_rows: HashSet<(usize, usize)> = HashSet::new();
    let mut overrides: HashMap<(usize, usize), f64> = HashMap::new();

    for &loc in &non_shared_rows {
        keep_rows.insert(loc);
    }

    for (key, accum) in &group_accums {
        let bdd_result = weighted_model_count(
            &accum.base_facts,
            &accum.base_probs,
            semiring_op,
            max_bdd_variables,
        );

        if bdd_result.approximated {
            // Warn once per (rule, key group); skipped if the lock cannot be acquired.
            if let Ok(mut warnings) = warnings_slot.write() {
                let key_desc = format!("{key:?}");
                let already_warned = warnings.iter().any(|w| {
                    w.code == RuntimeWarningCode::BddLimitExceeded
                        && w.rule_name == rule.name
                        && w.key_group.as_deref() == Some(&key_desc)
                });
                if !already_warned {
                    warnings.push(RuntimeWarning {
                        code: RuntimeWarningCode::BddLimitExceeded,
                        message: format!(
                            "Rule '{}': BDD variable limit exceeded ({} > {}). \
                             Falling back to independence-mode result.",
                            rule.name, bdd_result.variable_count, max_bdd_variables
                        ),
                        rule_name: rule.name.clone(),
                        variable_count: Some(bdd_result.variable_count),
                        key_group: Some(key_desc),
                    });
                }
            }
            if let Ok(mut approx) = approximate_slot.write() {
                let key_desc = format!("{key:?}");
                approx.entry(rule.name.clone()).or_default().push(key_desc);
            }
            // Keep every row so the group is folded as if no correction ran.
            for &loc in &accum.row_locations {
                keep_rows.insert(loc);
            }
        } else {
            // Collapse the group to its representative carrying the exact probability.
            keep_rows.insert(accum.representative);
            overrides.insert(accum.representative, bdd_result.probability);
        }
    }

    let mut result_batches = Vec::new();
    for (batch_idx, batch) in pre_fold_facts.iter().enumerate() {
        let kept_indices: Vec<usize> = (0..batch.num_rows())
            .filter(|&row_idx| keep_rows.contains(&(batch_idx, row_idx)))
            .collect();

        // Batches with no surviving rows are dropped from the output.
        if kept_indices.is_empty() {
            continue;
        }

        let indices = arrow::array::UInt32Array::from(
            kept_indices.iter().map(|&i| i as u32).collect::<Vec<_>>(),
        );
        let mut columns: Vec<arrow::array::ArrayRef> = batch
            .columns()
            .iter()
            .map(|col| arrow::compute::take(col, &indices, None))
            .collect::<Result<Vec<_>, _>>()
            .map_err(arrow_err)?;

        // Override (if any) for each kept row, in kept-row order.
        let override_map: Vec<Option<f64>> = kept_indices
            .iter()
            .map(|&row_idx| overrides.get(&(batch_idx, row_idx)).copied())
            .collect();

        if override_map.iter().any(|o| o.is_some()) && prob_col_idx < columns.len() {
            // NOTE(review): when the probability column is not a Float64Array,
            // non-overridden rows fall back to 0.0, and nulls in the existing
            // column are not carried over — confirm this is intended.
            let existing_prob = columns[prob_col_idx]
                .as_any()
                .downcast_ref::<arrow::array::Float64Array>();
            let new_values: Vec<f64> = override_map
                .iter()
                .enumerate()
                .map(|(i, ov)| match ov {
                    Some(p) => *p,
                    None => existing_prob.map(|arr| arr.value(i)).unwrap_or(0.0),
                })
                .collect();
            columns[prob_col_idx] = Arc::new(arrow::array::Float64Array::from(new_values));
        }

        let result_batch = RecordBatch::try_new(batch.schema(), columns).map_err(arrow_err)?;
        result_batches.push(result_batch);
    }

    Ok(result_batches)
}
2691
2692fn value_to_f64(val: &uni_common::Value) -> Option<f64> {
2694 match val {
2695 uni_common::Value::Float(f) => Some(*f),
2696 uni_common::Value::Int(i) => Some(*i as f64),
2697 _ => None,
2698 }
2699}
2700
2701fn compute_lineage(
2708 fact_hash: &[u8],
2709 tracker: &Arc<ProvenanceStore>,
2710 visited: &mut HashSet<Vec<u8>>,
2711) -> HashSet<Vec<u8>> {
2712 if !visited.insert(fact_hash.to_vec()) {
2713 return HashSet::new(); }
2715
2716 match tracker.lookup(fact_hash) {
2717 Some(entry) if !entry.support.is_empty() => {
2718 let mut bases = HashSet::new();
2719 for input in &entry.support {
2720 let child_bases = compute_lineage(&input.base_fact_id, tracker, visited);
2721 bases.extend(child_bases);
2722 }
2723 bases
2724 }
2725 _ => {
2726 let mut set = HashSet::new();
2728 set.insert(fact_hash.to_vec());
2729 set
2730 }
2731 }
2732}
2733
2734fn find_clause_for_row(
2739 delta_batch: &RecordBatch,
2740 row_idx: usize,
2741 all_indices: &[usize],
2742 clause_candidates: &[Vec<RecordBatch>],
2743) -> usize {
2744 let target_key = extract_scalar_key(delta_batch, all_indices, row_idx);
2745 for (clause_idx, batches) in clause_candidates.iter().enumerate() {
2746 for batch in batches {
2747 if batch.num_columns() != all_indices.len() {
2748 continue;
2749 }
2750 for r in 0..batch.num_rows() {
2751 if extract_scalar_key(batch, all_indices, r) == target_key {
2752 return clause_idx;
2753 }
2754 }
2755 }
2756 }
2757 0
2758}
2759
2760fn batch_row_to_value_map(
2762 batch: &RecordBatch,
2763 row_idx: usize,
2764) -> std::collections::HashMap<String, Value> {
2765 use uni_store::storage::arrow_convert::arrow_to_value;
2766
2767 let schema = batch.schema();
2768 schema
2769 .fields()
2770 .iter()
2771 .enumerate()
2772 .map(|(col_idx, field)| {
2773 let col = batch.column(col_idx);
2774 let val = arrow_to_value(col.as_ref(), row_idx, None);
2775 (field.name().clone(), val)
2776 })
2777 .collect()
2778}
2779
2780pub fn apply_anti_join(
2785 batches: Vec<RecordBatch>,
2786 neg_facts: &[RecordBatch],
2787 left_col: &str,
2788 right_col: &str,
2789) -> datafusion::error::Result<Vec<RecordBatch>> {
2790 use arrow::compute::filter_record_batch;
2791 use arrow_array::{Array as _, BooleanArray, UInt64Array};
2792
2793 let mut banned: std::collections::HashSet<u64> = std::collections::HashSet::new();
2795 for batch in neg_facts {
2796 let Ok(idx) = batch.schema().index_of(right_col) else {
2797 continue;
2798 };
2799 let arr = batch.column(idx);
2800 let Some(vids) = arr.as_any().downcast_ref::<UInt64Array>() else {
2801 continue;
2802 };
2803 for i in 0..vids.len() {
2804 if !vids.is_null(i) {
2805 banned.insert(vids.value(i));
2806 }
2807 }
2808 }
2809
2810 if banned.is_empty() {
2811 return Ok(batches);
2812 }
2813
2814 let mut result = Vec::new();
2816 for batch in batches {
2817 let Ok(idx) = batch.schema().index_of(left_col) else {
2818 result.push(batch);
2819 continue;
2820 };
2821 let arr = batch.column(idx);
2822 let Some(vids) = arr.as_any().downcast_ref::<UInt64Array>() else {
2823 result.push(batch);
2824 continue;
2825 };
2826 let keep: Vec<bool> = (0..vids.len())
2827 .map(|i| vids.is_null(i) || !banned.contains(&vids.value(i)))
2828 .collect();
2829 let keep_arr = BooleanArray::from(keep);
2830 let filtered = filter_record_batch(&batch, &keep_arr).map_err(arrow_err)?;
2831 if filtered.num_rows() > 0 {
2832 result.push(filtered);
2833 }
2834 }
2835 Ok(result)
2836}
2837
2838#[allow(clippy::too_many_arguments)]
2859pub(crate) async fn apply_model_invocations(
2860 batches: Vec<RecordBatch>,
2861 invocations: &[uni_locy::ModelInvocation],
2862 registry: &Arc<ClassifierRegistry>,
2863 cache: Option<&Arc<uni_locy::ModelInvocationCache>>,
2864 provenance_store: Option<&Arc<uni_locy::NeuralProvenanceStore>>,
2865 path_context_handles: &HashMap<
2866 String,
2867 crate::query::df_graph::locy_model_invoke::PathContextHandle,
2868 >,
2869 xervo_runtime: &crate::query::df_graph::locy_model_invoke::XervoRuntimeHandle,
2870 graph_algo: &crate::query::df_graph::locy_model_invoke::GraphAlgoHandle,
2871) -> DFResult<Vec<RecordBatch>> {
2872 use uni_locy::ClassifyInput;
2873 if batches.is_empty() || invocations.is_empty() {
2874 return Ok(batches);
2875 }
2876 let semantic_match_embeddings =
2880 pre_embed_semantic_match_queries(invocations, xervo_runtime).await?;
2881 let graph_feature_maps = precompute_graph_feature_maps(invocations, graph_algo).await?;
2886 let neighbor_feature_maps =
2887 precompute_neighbor_feature_maps(invocations, &batches, graph_algo).await?;
2888 let mut out_batches = Vec::with_capacity(batches.len());
2889 for batch in batches {
2890 let mut current = batch;
2891 for invocation in invocations {
2892 let classifier = registry.get(&invocation.model_name).ok_or_else(|| {
2893 datafusion::error::DataFusionError::Execution(format!(
2894 "neural classifier '{}' not registered; \
2895 add it to LocyConfig::classifier_registry",
2896 invocation.model_name
2897 ))
2898 })?;
2899
2900 let resolvers = build_feature_resolvers(
2912 ¤t,
2913 invocation,
2914 path_context_handles,
2915 &semantic_match_embeddings,
2916 &graph_feature_maps,
2917 &neighbor_feature_maps,
2918 )?;
2919
2920 let n_rows = current.num_rows();
2922 let mut inputs: Vec<ClassifyInput> = Vec::with_capacity(n_rows);
2923 let mut input_hashes: Vec<u64> = Vec::with_capacity(n_rows);
2924 for row_idx in 0..n_rows {
2925 let mut features = std::collections::HashMap::new();
2926 for resolver in &resolvers {
2927 let value = resolver.eval_row(¤t, row_idx)?;
2928 features.insert(resolver.binding_name.clone(), value);
2929 }
2930 let input = ClassifyInput { features };
2931 input_hashes.push(input.stable_hash());
2932 inputs.push(input);
2933 }
2934
2935 let mut probs: Vec<f64> = vec![0.0; n_rows];
2939 let mut miss_inputs: Vec<ClassifyInput> = Vec::new();
2940 let mut miss_row_indices: Vec<usize> = Vec::new();
2941 if let Some(c) = cache {
2942 for (row_idx, h) in input_hashes.iter().enumerate() {
2943 match c.get(&invocation.model_name, *h) {
2944 Some(v) => probs[row_idx] = v,
2945 None => {
2946 miss_row_indices.push(row_idx);
2947 miss_inputs.push(inputs[row_idx].clone());
2948 }
2949 }
2950 }
2951 } else {
2952 miss_row_indices = (0..n_rows).collect();
2953 miss_inputs = inputs.clone();
2954 }
2955
2956 let calibrator = classifier.get_calibrator();
2965 let (miss_raws, miss_calibrated) = if miss_inputs.is_empty() {
2966 (Vec::new(), Vec::new())
2967 } else if calibrator.is_some() {
2968 let pairs = classifier
2969 .raw_and_calibrated(&miss_inputs)
2970 .await
2971 .map_err(|e| datafusion::error::DataFusionError::Execution(e.to_string()))?;
2972 if pairs.len() != miss_inputs.len() {
2973 return Err(datafusion::error::DataFusionError::Execution(format!(
2974 "classifier '{}' raw_and_calibrated returned {} outputs for {} inputs",
2975 invocation.model_name,
2976 pairs.len(),
2977 miss_inputs.len()
2978 )));
2979 }
2980 let raws: Vec<f64> = pairs.iter().map(|(r, _)| *r).collect();
2981 let cals: Vec<f64> = pairs.iter().map(|(r, c)| c.unwrap_or(*r)).collect();
2982 (raws, cals)
2983 } else {
2984 let r = classifier
2985 .classify(&miss_inputs)
2986 .await
2987 .map_err(|e| datafusion::error::DataFusionError::Execution(e.to_string()))?;
2988 if r.len() != miss_inputs.len() {
2989 return Err(datafusion::error::DataFusionError::Execution(format!(
2990 "classifier '{}' returned {} outputs for {} inputs",
2991 invocation.model_name,
2992 r.len(),
2993 miss_inputs.len()
2994 )));
2995 }
2996 (r.clone(), r)
2998 };
2999 let mut row_raw: Vec<Option<f64>> = vec![None; n_rows];
3009 for (i, &row_idx) in miss_row_indices.iter().enumerate() {
3010 probs[row_idx] = miss_calibrated[i];
3011 row_raw[row_idx] = Some(miss_raws[i]);
3012 if let Some(c) = cache {
3013 c.insert(
3014 &invocation.model_name,
3015 input_hashes[row_idx],
3016 miss_calibrated[i],
3017 );
3018 }
3019 }
3020
3021 if let Some(store) = provenance_store {
3029 for row_idx in 0..n_rows {
3030 let calibrated_value = probs[row_idx];
3031 let (raw_value, calibrated) = match (row_raw[row_idx], &calibrator) {
3032 (Some(raw), Some(_)) => (raw, Some(calibrated_value)),
3033 (Some(raw), None) => (raw, None),
3034 (None, _) => (
3039 calibrated_value,
3040 calibrator.as_ref().map(|_| calibrated_value),
3041 ),
3042 };
3043 let band = calibrator
3044 .as_ref()
3045 .and_then(|c| c.confidence_band(calibrated_value));
3046 store.record(
3047 &invocation.model_name,
3048 input_hashes[row_idx],
3049 uni_locy::NeuralProvenanceRecord {
3050 raw_probability: raw_value,
3051 calibrated_probability: calibrated,
3052 confidence_band: band,
3053 feature_inputs: inputs[row_idx].features.clone(),
3061 },
3062 );
3063 }
3064 }
3065
3066 let out_col: Arc<dyn arrow_array::Array> =
3071 Arc::new(arrow_array::Float64Array::from(probs));
3072 let schema = current.schema();
3073 let target_idx = schema.index_of(&invocation.output_column).ok();
3074 let mut columns: Vec<Arc<dyn arrow_array::Array>> = current.columns().to_vec();
3075 let mut fields: Vec<Arc<arrow_schema::Field>> =
3076 schema.fields().iter().cloned().collect();
3077 match target_idx {
3078 Some(idx) => {
3079 columns[idx] = out_col;
3080 fields[idx] = Arc::new(arrow_schema::Field::new(
3083 &invocation.output_column,
3084 arrow_schema::DataType::Float64,
3085 true,
3086 ));
3087 }
3088 None => {
3089 columns.push(out_col);
3090 fields.push(Arc::new(arrow_schema::Field::new(
3091 &invocation.output_column,
3092 arrow_schema::DataType::Float64,
3093 true,
3094 )));
3095 }
3096 }
3097 let new_schema = Arc::new(arrow_schema::Schema::new(fields));
3098 current = RecordBatch::try_new(new_schema, columns).map_err(arrow_err)?;
3099 }
3100 out_batches.push(current);
3101 }
3102 Ok(out_batches)
3103}
3104
/// Evaluator for a single model feature: the name the value is bound to in the
/// classifier input, plus the strategy used to compute it for a given row.
struct FeatureResolver {
    /// Key under which the evaluated value is inserted into the feature map.
    binding_name: String,
    /// How the value is obtained (see [`FeatureResolverKind`]).
    kind: FeatureResolverKind,
}
3116
/// Strategy used by [`FeatureResolver::eval_row`] to compute one feature value.
enum FeatureResolverKind {
    /// Read the value directly from the batch column at this index.
    Direct(usize),
    /// Compute a similarity score between two operands via `eval_similar_to_pure`.
    SimilarTo {
        left: FeatureValueSrc,
        right: FeatureValueSrc,
    },
    /// Look up the row's subject VID in a pre-built VID -> feature-value table.
    PathContext {
        /// Index of the batch column holding the subject VID.
        subject_col: usize,
        vid_to_value: Arc<HashMap<u64, uni_locy::FeatureValue>>,
    },
    /// Look up the row's subject VID in a pre-computed VID -> score table.
    GraphAlgoScore {
        /// Index of the batch column holding the subject VID.
        subject_col: usize,
        vid_to_score: Arc<HashMap<u64, f64>>,
    },
    /// Aggregate the pre-computed per-VID value list of the row's subject with `op`.
    NeighborAggregate {
        /// Index of the batch column holding the subject VID.
        subject_col: usize,
        op: NeighborAgg,
        vid_to_values: Arc<HashMap<u64, Vec<f64>>>,
    },
}
3150
/// Aggregation applied to a subject's list of neighbor values.
///
/// Selected by feature function name: `avg_neighbor`, `max_neighbor`,
/// `sum_neighbor`.
#[derive(Debug, Clone, Copy)]
enum NeighborAgg {
    /// Arithmetic mean of the values.
    Avg,
    /// Largest value.
    Max,
    /// Sum of the values.
    Sum,
}
3157
/// Edge direction requested by a neighbor-aggregator feature.
///
/// Parsed case-insensitively from the optional 4th argument
/// (`OUTGOING` | `INCOMING` | `BOTH`); `Outgoing` is used when it is omitted.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
enum NeighborDirection {
    Outgoing,
    Incoming,
    /// Both outgoing and incoming edges.
    Both,
}
3167
3168impl NeighborDirection {
3169 fn store_directions(self) -> &'static [uni_store::storage::direction::Direction] {
3170 use uni_store::storage::direction::Direction;
3171 match self {
3172 NeighborDirection::Outgoing => &[Direction::Outgoing],
3173 NeighborDirection::Incoming => &[Direction::Incoming],
3174 NeighborDirection::Both => &[Direction::Outgoing, Direction::Incoming],
3175 }
3176 }
3177}
3178
3179impl NeighborAgg {
3180 fn from_fn_name(name: &str) -> Option<Self> {
3181 match name {
3182 "avg_neighbor" => Some(NeighborAgg::Avg),
3183 "max_neighbor" => Some(NeighborAgg::Max),
3184 "sum_neighbor" => Some(NeighborAgg::Sum),
3185 _ => None,
3186 }
3187 }
3188
3189 fn apply(self, values: &[f64]) -> Option<f64> {
3190 if values.is_empty() {
3191 return None;
3192 }
3193 match self {
3194 NeighborAgg::Avg => Some(values.iter().sum::<f64>() / values.len() as f64),
3195 NeighborAgg::Max => values.iter().copied().reduce(f64::max),
3196 NeighborAgg::Sum => Some(values.iter().sum()),
3197 }
3198 }
3199}
3200
/// Operand of a computed feature: either a batch column or a fixed value.
enum FeatureValueSrc {
    /// Index of the batch column to read per row.
    Col(usize),
    /// Constant value used for every row.
    Const(uni_common::Value),
}
3207
3208impl FeatureValueSrc {
3209 fn resolve(&self, batch: &RecordBatch, row_idx: usize) -> uni_common::Value {
3210 match self {
3211 FeatureValueSrc::Col(idx) => extract_common_value(batch.column(*idx).as_ref(), row_idx),
3212 FeatureValueSrc::Const(v) => v.clone(),
3213 }
3214 }
3215}
3216
3217impl FeatureResolver {
3218 fn eval_row(&self, batch: &RecordBatch, row_idx: usize) -> DFResult<uni_locy::FeatureValue> {
3219 match &self.kind {
3220 FeatureResolverKind::Direct(idx) => {
3221 Ok(extract_feature_value(batch.column(*idx).as_ref(), row_idx))
3222 }
3223 FeatureResolverKind::SimilarTo { left, right } => {
3224 let lv = left.resolve(batch, row_idx);
3225 let rv = right.resolve(batch, row_idx);
3226 match crate::query::similar_to::eval_similar_to_pure(&lv, &rv) {
3227 Ok(uni_common::Value::Float(f)) => Ok(uni_locy::FeatureValue::Float(f)),
3228 Ok(_) => Ok(uni_locy::FeatureValue::Null),
3229 Err(e) => Err(datafusion::error::DataFusionError::Execution(format!(
3230 "similar_to UDF failed: {e}"
3231 ))),
3232 }
3233 }
3234 FeatureResolverKind::PathContext {
3235 subject_col,
3236 vid_to_value,
3237 } => {
3238 let col = batch.column(*subject_col);
3239 if col.is_null(row_idx) {
3240 return Ok(uni_locy::FeatureValue::Null);
3241 }
3242 if let Some(arr) = col.as_any().downcast_ref::<arrow_array::UInt64Array>() {
3243 let vid = arr.value(row_idx);
3244 Ok(vid_to_value
3245 .get(&vid)
3246 .cloned()
3247 .unwrap_or(uni_locy::FeatureValue::Null))
3248 } else if let Some(arr) = col.as_any().downcast_ref::<arrow_array::Int64Array>() {
3249 let vid = arr.value(row_idx) as u64;
3250 Ok(vid_to_value
3251 .get(&vid)
3252 .cloned()
3253 .unwrap_or(uni_locy::FeatureValue::Null))
3254 } else {
3255 Ok(uni_locy::FeatureValue::Null)
3256 }
3257 }
3258 FeatureResolverKind::GraphAlgoScore {
3259 subject_col,
3260 vid_to_score,
3261 } => {
3262 let col = batch.column(*subject_col);
3263 if col.is_null(row_idx) {
3264 return Ok(uni_locy::FeatureValue::Null);
3265 }
3266 let vid_opt: Option<u64> = if let Some(arr) =
3267 col.as_any().downcast_ref::<arrow_array::UInt64Array>()
3268 {
3269 Some(arr.value(row_idx))
3270 } else if let Some(arr) = col.as_any().downcast_ref::<arrow_array::Int64Array>() {
3271 Some(arr.value(row_idx) as u64)
3272 } else {
3273 match extract_common_value(col.as_ref(), row_idx) {
3279 uni_common::Value::Node(n) => Some(n.vid.as_u64()),
3280 uni_common::Value::Int(i) => Some(i as u64),
3281 _ => None,
3282 }
3283 };
3284 Ok(vid_opt
3285 .and_then(|v| vid_to_score.get(&v).copied())
3286 .map(uni_locy::FeatureValue::Float)
3287 .unwrap_or(uni_locy::FeatureValue::Null))
3288 }
3289 FeatureResolverKind::NeighborAggregate {
3290 subject_col,
3291 op,
3292 vid_to_values,
3293 } => {
3294 let vid_opt = extract_vid_from_column(batch.column(*subject_col).as_ref(), row_idx);
3295 Ok(vid_opt
3296 .and_then(|v| vid_to_values.get(&v))
3297 .and_then(|values| op.apply(values))
3298 .map(uni_locy::FeatureValue::Float)
3299 .unwrap_or(uni_locy::FeatureValue::Null))
3300 }
3301 }
3302 }
3303}
3304
3305fn extract_vid_from_column(col: &dyn arrow_array::Array, row_idx: usize) -> Option<u64> {
3310 if col.is_null(row_idx) {
3311 return None;
3312 }
3313 if let Some(arr) = col.as_any().downcast_ref::<arrow_array::UInt64Array>() {
3314 return Some(arr.value(row_idx));
3315 }
3316 if let Some(arr) = col.as_any().downcast_ref::<arrow_array::Int64Array>() {
3317 return Some(arr.value(row_idx) as u64);
3318 }
3319 match extract_common_value(col, row_idx) {
3320 uni_common::Value::Node(n) => Some(n.vid.as_u64()),
3321 uni_common::Value::Int(i) => Some(i as u64),
3322 _ => None,
3323 }
3324}
3325
3326#[allow(clippy::too_many_arguments)]
3327fn build_feature_resolvers(
3328 batch: &RecordBatch,
3329 invocation: &uni_locy::ModelInvocation,
3330 path_context_handles: &HashMap<
3331 String,
3332 crate::query::df_graph::locy_model_invoke::PathContextHandle,
3333 >,
3334 semantic_match_embeddings: &HashMap<String, Vec<f32>>,
3335 graph_feature_maps: &HashMap<String, Arc<HashMap<u64, f64>>>,
3336 neighbor_feature_maps: &NeighborFeatureMaps,
3337) -> DFResult<Vec<FeatureResolver>> {
3338 use uni_cypher::ast::Expr;
3339 let schema = batch.schema();
3340 let lookup_col = |name_or_property: String| -> DFResult<usize> {
3341 schema.index_of(&name_or_property).map_err(|_| {
3342 datafusion::error::DataFusionError::Execution(format!(
3343 "feature column '{name_or_property}' not found in clause body output schema"
3344 ))
3345 })
3346 };
3347 let resolve_src = |expr: &Expr| -> DFResult<FeatureValueSrc> {
3352 match expr {
3353 Expr::Variable(name) => {
3354 let col = if schema.index_of(name).is_ok() {
3355 name.clone()
3356 } else {
3357 let vid_name = format!("{}._vid", name);
3358 if schema.index_of(&vid_name).is_ok() {
3359 vid_name
3360 } else {
3361 name.clone()
3362 }
3363 };
3364 Ok(FeatureValueSrc::Col(lookup_col(col)?))
3365 }
3366 Expr::Property(boxed, prop) if matches!(boxed.as_ref(), Expr::Variable(_)) => {
3367 let Expr::Variable(v) = boxed.as_ref() else {
3368 unreachable!()
3369 };
3370 let direct = format!("{}.{}", v, prop);
3371 let col = if schema.index_of(&direct).is_ok() {
3372 direct
3373 } else {
3374 format!("__feat_{}_{}", v, prop)
3375 };
3376 Ok(FeatureValueSrc::Col(lookup_col(col)?))
3377 }
3378 Expr::Literal(lit) => Ok(FeatureValueSrc::Const(lit.to_value())),
3379 Expr::List(items) => {
3380 let mut out = Vec::with_capacity(items.len());
3381 for it in items {
3382 out.push(match it {
3383 Expr::Literal(lit) => lit.to_value(),
3384 _ => uni_common::Value::Null,
3385 });
3386 }
3387 Ok(FeatureValueSrc::Const(uni_common::Value::List(out)))
3388 }
3389 other => Err(datafusion::error::DataFusionError::Execution(format!(
3390 "unsupported feature sub-expression: {other:?}"
3391 ))),
3392 }
3393 };
3394
3395 if let Some(pc) = &invocation.path_context {
3403 let handle = path_context_handles.get(&pc.source_rule).ok_or_else(|| {
3404 datafusion::error::DataFusionError::Execution(format!(
3405 "model '{}' path_context references rule '{}' but no DerivedScanHandle \
3406 was registered; this should never happen — the build_clause path \
3407 mints a handle for every distinct source_rule in the invocation set",
3408 invocation.model_name, pc.source_rule
3409 ))
3410 })?;
3411 let subject_col = schema
3412 .index_of(&format!("{}._vid", pc.subject_var))
3413 .or_else(|_| schema.index_of(&pc.subject_var))
3414 .map_err(|_| {
3415 datafusion::error::DataFusionError::Execution(format!(
3416 "model '{}' path_context: subject column '{}' (or '{0}._vid') not \
3417 in body batch schema",
3418 invocation.model_name, pc.subject_var
3419 ))
3420 })?;
3421 let vid_to_value =
3422 build_path_context_lookup(handle, &pc.subject_var, &pc.column, &invocation.model_name)?;
3423 return Ok(vec![FeatureResolver {
3424 binding_name: pc.column.clone(),
3425 kind: FeatureResolverKind::PathContext {
3426 subject_col,
3427 vid_to_value: Arc::new(vid_to_value),
3428 },
3429 }]);
3430 }
3431
3432 let mut out = Vec::with_capacity(invocation.feature_exprs.len());
3433 for (i, fexpr) in invocation.feature_exprs.iter().enumerate() {
3434 let binding_name = invocation.feature_names[i].clone();
3435 let kind = match fexpr {
3436 Expr::FunctionCall { name, args, .. } if name == "similar_to" => {
3437 if args.len() != 2 {
3438 return Err(datafusion::error::DataFusionError::Execution(format!(
3439 "similar_to expects 2 args, got {}",
3440 args.len()
3441 )));
3442 }
3443 FeatureResolverKind::SimilarTo {
3444 left: resolve_src(&args[0])?,
3445 right: resolve_src(&args[1])?,
3446 }
3447 }
3448 Expr::FunctionCall { name, args, .. } if name == "semantic_match" => {
3449 if args.len() != 2 {
3454 return Err(datafusion::error::DataFusionError::Execution(format!(
3455 "semantic_match expects 2 args, got {}",
3456 args.len()
3457 )));
3458 }
3459 let text = match &args[1] {
3460 Expr::Literal(uni_cypher::ast::CypherLiteral::String(s)) => s.clone(),
3461 other => {
3462 return Err(datafusion::error::DataFusionError::Execution(format!(
3463 "semantic_match: 2nd arg must be a string literal, got {other:?}"
3464 )));
3465 }
3466 };
3467 let embedded = semantic_match_embeddings.get(&text).ok_or_else(|| {
3468 datafusion::error::DataFusionError::Execution(format!(
3469 "semantic_match: query text '{text}' was not pre-embedded. \
3470 This is a bug — `apply_model_invocations` should have \
3471 embedded all unique semantic_match texts up front. Most \
3472 likely the Xervo runtime is not configured (configure \
3473 via `LocyConfig::xervo_runtime` or its equivalent)."
3474 ))
3475 })?;
3476 let right_vec: Vec<f32> = embedded.clone();
3477 FeatureResolverKind::SimilarTo {
3478 left: resolve_src(&args[0])?,
3479 right: FeatureValueSrc::Const(uni_common::Value::Vector(right_vec)),
3480 }
3481 }
3482 Expr::FunctionCall { name, args, .. }
3483 if matches!(
3484 name.as_str(),
3485 "degree_centrality"
3486 | "pagerank_score"
3487 | "closeness_centrality"
3488 | "betweenness_centrality"
3489 | "eigenvector_centrality"
3490 | "harmonic_centrality"
3491 | "katz_centrality"
3492 ) =>
3493 {
3494 if args.len() != 1 {
3495 return Err(datafusion::error::DataFusionError::Execution(format!(
3496 "{name} expects 1 arg, got {}",
3497 args.len()
3498 )));
3499 }
3500 let Expr::Variable(v) = &args[0] else {
3501 return Err(datafusion::error::DataFusionError::Execution(format!(
3502 "{name}(...) argument must be a node variable, got {:?}",
3503 args[0]
3504 )));
3505 };
3506 let subject_col = {
3507 let direct = schema.index_of(v).ok();
3508 let vid_name = format!("{}._vid", v);
3509 let vid_col = schema.index_of(&vid_name).ok();
3510 vid_col.or(direct).ok_or_else(|| {
3511 datafusion::error::DataFusionError::Execution(format!(
3512 "{name}: subject column '{v}' (or '{v}._vid') not in body batch schema"
3513 ))
3514 })?
3515 };
3516 let vid_to_score = graph_feature_maps.get(name).cloned().ok_or_else(|| {
3517 datafusion::error::DataFusionError::Execution(format!(
3518 "{name}: pre-computed score map missing. This is a bug — \
3519 `apply_model_invocations` should have called \
3520 `precompute_graph_feature_maps` for every graph-structural \
3521 feature before building resolvers. Most likely the graph \
3522 algorithm registry is not configured."
3523 ))
3524 })?;
3525 FeatureResolverKind::GraphAlgoScore {
3526 subject_col,
3527 vid_to_score,
3528 }
3529 }
3530 Expr::FunctionCall { name, args, .. }
3531 if matches!(
3532 name.as_str(),
3533 "avg_neighbor" | "max_neighbor" | "sum_neighbor"
3534 ) =>
3535 {
3536 if args.len() != 3 && args.len() != 4 {
3537 return Err(datafusion::error::DataFusionError::Execution(format!(
3538 "{name} expects 3 or 4 args, got {}",
3539 args.len()
3540 )));
3541 }
3542 let Expr::Variable(v) = &args[0] else {
3543 return Err(datafusion::error::DataFusionError::Execution(format!(
3544 "{name}(...) first argument must be a node variable, got {:?}",
3545 args[0]
3546 )));
3547 };
3548 let rel_type = match &args[1] {
3549 Expr::Literal(uni_cypher::ast::CypherLiteral::String(s)) => s.clone(),
3550 other => {
3551 return Err(datafusion::error::DataFusionError::Execution(format!(
3552 "{name}: 2nd arg must be a string literal (rel-type), got {other:?}"
3553 )));
3554 }
3555 };
3556 let prop_name = match &args[2] {
3557 Expr::Literal(uni_cypher::ast::CypherLiteral::String(s)) => s.clone(),
3558 other => {
3559 return Err(datafusion::error::DataFusionError::Execution(format!(
3560 "{name}: 3rd arg must be a string literal (property), got {other:?}"
3561 )));
3562 }
3563 };
3564 let direction_arg = match args.get(3) {
3565 None => NeighborDirection::Outgoing,
3566 Some(Expr::Literal(uni_cypher::ast::CypherLiteral::String(d))) => {
3567 match d.to_uppercase().as_str() {
3568 "OUTGOING" => NeighborDirection::Outgoing,
3569 "INCOMING" => NeighborDirection::Incoming,
3570 "BOTH" => NeighborDirection::Both,
3571 other => {
3572 return Err(datafusion::error::DataFusionError::Execution(
3573 format!(
3574 "{name}: direction must be OUTGOING|INCOMING|BOTH, got '{other}'"
3575 ),
3576 ));
3577 }
3578 }
3579 }
3580 Some(other) => {
3581 return Err(datafusion::error::DataFusionError::Execution(format!(
3582 "{name}: 4th arg must be a string literal (direction), got {other:?}"
3583 )));
3584 }
3585 };
3586 let subject_col = {
3587 let direct = schema.index_of(v).ok();
3588 let vid_name = format!("{}._vid", v);
3589 let vid_col = schema.index_of(&vid_name).ok();
3590 vid_col.or(direct).ok_or_else(|| {
3591 datafusion::error::DataFusionError::Execution(format!(
3592 "{name}: subject column '{v}' (or '{v}._vid') not in body batch schema"
3593 ))
3594 })?
3595 };
3596 let vid_to_values = neighbor_feature_maps
3597 .get(&(rel_type.clone(), prop_name.clone(), direction_arg))
3598 .cloned()
3599 .ok_or_else(|| {
3600 datafusion::error::DataFusionError::Execution(format!(
3601 "{name}: pre-computed neighbor map missing for ({rel_type}, {prop_name}, {direction_arg:?}). \
3602 This is a bug — `apply_model_invocations` should have called \
3603 `precompute_neighbor_feature_maps` for every neighbor-aggregator \
3604 feature before building resolvers."
3605 ))
3606 })?;
3607 let op = NeighborAgg::from_fn_name(name).unwrap();
3608 FeatureResolverKind::NeighborAggregate {
3609 subject_col,
3610 op,
3611 vid_to_values,
3612 }
3613 }
3614 other => match resolve_src(other)? {
3615 FeatureValueSrc::Col(idx) => FeatureResolverKind::Direct(idx),
3616 FeatureValueSrc::Const(_) => {
3617 return Err(datafusion::error::DataFusionError::Execution(format!(
3618 "model '{}' feature must reference a variable or property — got a literal",
3619 invocation.model_name
3620 )));
3621 }
3622 },
3623 };
3624 out.push(FeatureResolver { binding_name, kind });
3625 }
3626 Ok(out)
3627}
3628
3629async fn pre_embed_semantic_match_queries(
3636 invocations: &[uni_locy::ModelInvocation],
3637 xervo_runtime: &crate::query::df_graph::locy_model_invoke::XervoRuntimeHandle,
3638) -> DFResult<HashMap<String, Vec<f32>>> {
3639 use uni_cypher::ast::{CypherLiteral, Expr};
3640 let mut needed: Vec<(String, String)> = Vec::new();
3650 for inv in invocations {
3651 let alias = inv
3652 .embedder_alias
3653 .clone()
3654 .unwrap_or_else(|| "default".to_string());
3655 for fexpr in &inv.feature_exprs {
3656 if let Expr::FunctionCall { name, args, .. } = fexpr
3657 && name == "semantic_match"
3658 && args.len() == 2
3659 && let Expr::Literal(CypherLiteral::String(s)) = &args[1]
3660 {
3661 let tuple = (s.clone(), alias.clone());
3662 if !needed.contains(&tuple) {
3663 needed.push(tuple);
3664 }
3665 }
3666 }
3667 }
3668 if needed.is_empty() {
3669 return Ok(HashMap::new());
3670 }
3671 let runtime = xervo_runtime.as_ref().ok_or_else(|| {
3672 datafusion::error::DataFusionError::Execution(
3673 "semantic_match: Uni-Xervo runtime not configured. Either provide \
3674 one via `LocyConfig::xervo_runtime` (or its equivalent setup \
3675 path) or pre-compute the query embedding and pass it via \
3676 `similar_to(prop, <literal_vector>)`."
3677 .to_string(),
3678 )
3679 })?;
3680 let mut by_alias: HashMap<String, Vec<String>> = HashMap::new();
3683 for (text, alias) in &needed {
3684 by_alias
3685 .entry(alias.clone())
3686 .or_default()
3687 .push(text.clone());
3688 }
3689 let mut out: HashMap<String, Vec<f32>> = HashMap::new();
3690 for (alias, texts) in by_alias {
3691 let embedder = runtime.embedding(&alias).await.map_err(|e| {
3692 datafusion::error::DataFusionError::Execution(format!(
3693 "semantic_match: failed to obtain embedder for alias '{alias}': {e}. \
3694 Register an embedder under that alias in your Uni-Xervo runtime, or \
3695 pre-compute the query embedding and pass via similar_to."
3696 ))
3697 })?;
3698 let text_refs: Vec<&str> = texts.iter().map(|s| s.as_str()).collect();
3699 let embeddings = embedder.embed(text_refs).await.map_err(|e| {
3700 datafusion::error::DataFusionError::Execution(format!(
3701 "semantic_match: embedder '{alias}' call failed: {e}"
3702 ))
3703 })?;
3704 if embeddings.len() != texts.len() {
3705 return Err(datafusion::error::DataFusionError::Execution(format!(
3706 "semantic_match: embedder '{alias}' returned {} vectors for {} queries",
3707 embeddings.len(),
3708 texts.len()
3709 )));
3710 }
3711 for (text, vec) in texts.into_iter().zip(embeddings) {
3712 out.insert(text, vec);
3713 }
3714 }
3715 Ok(out)
3716}
3717
3718async fn precompute_graph_feature_maps(
3731 invocations: &[uni_locy::ModelInvocation],
3732 graph_algo: &crate::query::df_graph::locy_model_invoke::GraphAlgoHandle,
3733) -> DFResult<HashMap<String, Arc<HashMap<u64, f64>>>> {
3734 use futures::StreamExt;
3735 use uni_algo::algo::procedures::AlgoContext;
3736 use uni_cypher::ast::Expr;
3737
3738 fn procedure_for(fn_name: &str) -> Option<&'static str> {
3741 match fn_name {
3742 "degree_centrality" => Some("uni.algo.degreeCentrality"),
3743 "pagerank_score" => Some("uni.algo.pageRank"),
3744 "closeness_centrality" => Some("uni.algo.closeness"),
3745 "betweenness_centrality" => Some("uni.algo.betweenness"),
3746 "eigenvector_centrality" => Some("uni.algo.eigenvectorCentrality"),
3747 "harmonic_centrality" => Some("uni.algo.harmonicCentrality"),
3748 "katz_centrality" => Some("uni.algo.katzCentrality"),
3749 _ => None,
3750 }
3751 }
3752
3753 let mut needed: Vec<String> = Vec::new();
3757 for inv in invocations {
3758 for fexpr in &inv.feature_exprs {
3759 if let Expr::FunctionCall { name, .. } = fexpr
3760 && procedure_for(name).is_some()
3761 && !needed.contains(name)
3762 {
3763 needed.push(name.clone());
3764 }
3765 }
3766 }
3767 if needed.is_empty() {
3768 return Ok(HashMap::new());
3769 }
3770
3771 let registry = graph_algo.registry.as_ref().ok_or_else(|| {
3772 datafusion::error::DataFusionError::Execution(
3773 "graph-structural FEATURE invoked but no `AlgorithmRegistry` is \
3774 configured. Configure one on `GraphExecutionContext::with_algo_registry`."
3775 .to_string(),
3776 )
3777 })?;
3778 let storage = graph_algo.storage.as_ref().ok_or_else(|| {
3779 datafusion::error::DataFusionError::Execution(
3780 "graph-structural FEATURE invoked but no storage handle was \
3781 threaded into the FEATURE runtime. This is a bug in df_planner."
3782 .to_string(),
3783 )
3784 })?;
3785
3786 let mut out: HashMap<String, Arc<HashMap<u64, f64>>> = HashMap::new();
3787 for fn_name in needed {
3788 let proc_name = procedure_for(&fn_name).unwrap();
3789 let procedure = registry.get(proc_name).ok_or_else(|| {
3790 datafusion::error::DataFusionError::Execution(format!(
3791 "graph-structural FEATURE '{fn_name}' resolves to procedure \
3792 '{proc_name}' which is not in the algorithm registry"
3793 ))
3794 })?;
3795 let args: Vec<serde_json::Value> = vec![
3800 serde_json::Value::Array(Vec::new()),
3801 serde_json::Value::Array(Vec::new()),
3802 ];
3803 let algo_ctx = AlgoContext::new(
3804 storage.clone(),
3805 graph_algo.l0_manager.as_ref().map(Arc::clone),
3806 );
3807 let filled_args = procedure
3828 .signature()
3829 .validate_args(args.clone())
3830 .map_err(|e| {
3831 datafusion::error::DataFusionError::Execution(format!(
3832 "graph-structural FEATURE '{fn_name}': argument validation failed: {e}"
3833 ))
3834 })?;
3835 let projection = uni_algo::algo::procedure_template::build_projection_from_direct_args(
3836 procedure.as_ref(),
3837 &algo_ctx,
3838 &filled_args,
3839 )
3840 .await
3841 .map_err(|e| {
3842 datafusion::error::DataFusionError::Execution(format!(
3843 "graph-structural FEATURE '{fn_name}': projection build failed: {e}"
3844 ))
3845 })?;
3846 let mut stream = procedure.execute_with_projection(algo_ctx, args, projection);
3847 let mut score_map: HashMap<u64, f64> = HashMap::new();
3848 let sig = procedure.signature();
3849 let node_idx = sig
3850 .yields
3851 .iter()
3852 .position(|(n, _)| *n == "nodeId")
3853 .ok_or_else(|| {
3854 datafusion::error::DataFusionError::Execution(format!(
3855 "procedure '{proc_name}' yield schema missing 'nodeId'"
3856 ))
3857 })?;
3858 let score_idx = sig
3863 .yields
3864 .iter()
3865 .position(|(n, _)| *n == "score" || *n == "centrality")
3866 .ok_or_else(|| {
3867 datafusion::error::DataFusionError::Execution(format!(
3868 "procedure '{proc_name}' yield schema missing a numeric score column \
3869 (expected 'score' or 'centrality')"
3870 ))
3871 })?;
3872 while let Some(row_res) = stream.next().await {
3873 let row = row_res.map_err(|e| {
3874 datafusion::error::DataFusionError::Execution(format!(
3875 "graph-structural FEATURE '{fn_name}': procedure '{proc_name}' failed: {e}"
3876 ))
3877 })?;
3878 let vid_v = row.values.get(node_idx);
3879 let score_v = row.values.get(score_idx);
3880 let (Some(vid_v), Some(score_v)) = (vid_v, score_v) else {
3881 continue;
3882 };
3883 let vid = vid_v.as_u64().or_else(|| vid_v.as_i64().map(|i| i as u64));
3884 let score = score_v
3885 .as_f64()
3886 .or_else(|| score_v.as_i64().map(|i| i as f64));
3887 if let (Some(vid), Some(score)) = (vid, score) {
3888 score_map.insert(vid, score);
3889 }
3890 }
3891 out.insert(fn_name, Arc::new(score_map));
3892 }
3893 Ok(out)
3894}
3895
/// Pre-computed inputs for neighbor-aggregator features.
///
/// Keyed by `(relationship type, property name, direction)`; each entry maps a
/// subject VID to the list of `f64` values that a `NeighborAgg` is applied to.
// NOTE(review): the lists are presumably the neighbors' property values —
// confirm against `precompute_neighbor_feature_maps`.
type NeighborFeatureMaps =
    HashMap<(String, String, NeighborDirection), Arc<HashMap<u64, Vec<f64>>>>;
3923
3924async fn precompute_neighbor_feature_maps(
3925 invocations: &[uni_locy::ModelInvocation],
3926 batches: &[RecordBatch],
3927 graph_algo: &crate::query::df_graph::locy_model_invoke::GraphAlgoHandle,
3928) -> DFResult<NeighborFeatureMaps> {
3929 use uni_cypher::ast::{CypherLiteral, Expr};
3930
3931 let parse_direction = |arg: Option<&Expr>| -> Option<NeighborDirection> {
3936 match arg {
3937 None => Some(NeighborDirection::Outgoing),
3938 Some(Expr::Literal(CypherLiteral::String(d))) => match d.to_uppercase().as_str() {
3939 "OUTGOING" => Some(NeighborDirection::Outgoing),
3940 "INCOMING" => Some(NeighborDirection::Incoming),
3941 "BOTH" => Some(NeighborDirection::Both),
3942 _ => None,
3943 },
3944 _ => None,
3945 }
3946 };
3947 let mut needed: Vec<(String, String, String, NeighborDirection)> = Vec::new();
3948 for inv in invocations {
3949 for fexpr in &inv.feature_exprs {
3950 if let Expr::FunctionCall { name, args, .. } = fexpr
3951 && NeighborAgg::from_fn_name(name).is_some()
3952 && (args.len() == 3 || args.len() == 4)
3953 && let Expr::Variable(v) = &args[0]
3954 && let Expr::Literal(CypherLiteral::String(rel)) = &args[1]
3955 && let Expr::Literal(CypherLiteral::String(prop)) = &args[2]
3956 && let Some(direction) = parse_direction(args.get(3))
3957 {
3958 let tuple = (v.clone(), rel.clone(), prop.clone(), direction);
3959 if !needed.contains(&tuple) {
3960 needed.push(tuple);
3961 }
3962 }
3963 }
3964 }
3965 if needed.is_empty() {
3966 return Ok(HashMap::new());
3967 }
3968
3969 let storage = graph_algo.storage.as_ref().ok_or_else(|| {
3970 datafusion::error::DataFusionError::Execution(
3971 "neighbor-aggregator FEATURE invoked but no storage handle was \
3972 threaded into the FEATURE runtime. This is a bug in df_planner."
3973 .to_string(),
3974 )
3975 })?;
3976 let property_manager = graph_algo.property_manager.as_ref().ok_or_else(|| {
3977 datafusion::error::DataFusionError::Execution(
3978 "neighbor-aggregator FEATURE invoked but no PropertyManager was \
3979 threaded into the FEATURE runtime. This is a bug in df_planner."
3980 .to_string(),
3981 )
3982 })?;
3983 let query_ctx = graph_algo.l0_buffers.as_ref().map(|bufs| {
3989 uni_store::runtime::context::QueryContext::new_with_pending(
3990 bufs.current.clone(),
3991 bufs.transaction.clone(),
3992 bufs.pending_flush.clone(),
3993 )
3994 });
3995
3996 let mut by_key: HashMap<(String, String, NeighborDirection), Vec<String>> = HashMap::new();
4000 for (subject_var, rel, prop, direction) in needed {
4001 by_key
4002 .entry((rel, prop, direction))
4003 .or_default()
4004 .push(subject_var);
4005 }
4006
4007 let mut out: NeighborFeatureMaps = HashMap::new();
4008 for ((rel_type, prop_name, direction), subject_vars) in by_key {
4009 let schema = storage.schema_manager().schema();
4011 let Some(edge_meta) = schema.edge_types.get(&rel_type) else {
4012 out.insert((rel_type, prop_name, direction), Arc::new(HashMap::new()));
4015 continue;
4016 };
4017 let edge_type_id = edge_meta.id;
4018
4019 let edge_ver = storage.get_edge_version_by_id(edge_type_id);
4022 for dir in direction.store_directions() {
4023 storage
4024 .warm_adjacency(edge_type_id, *dir, edge_ver)
4025 .await
4026 .map_err(|e| {
4027 datafusion::error::DataFusionError::Execution(format!(
4028 "neighbor-aggregator warm_adjacency for '{rel_type}' / {dir:?} failed: {e}"
4029 ))
4030 })?;
4031 }
4032
4033 let mut subject_vids: std::collections::HashSet<u64> = std::collections::HashSet::new();
4036 for subject_var in &subject_vars {
4037 for batch in batches {
4038 let schema = batch.schema();
4039 let col_idx = schema
4040 .index_of(&format!("{}._vid", subject_var))
4041 .ok()
4042 .or_else(|| schema.index_of(subject_var).ok());
4043 let Some(col_idx) = col_idx else { continue };
4044 let col = batch.column(col_idx);
4045 for row in 0..batch.num_rows() {
4046 if let Some(v) = extract_vid_from_column(col.as_ref(), row) {
4047 subject_vids.insert(v);
4048 }
4049 }
4050 }
4051 }
4052
4053 let mut vid_to_values: HashMap<u64, Vec<f64>> = HashMap::new();
4058 let adj = storage.adjacency_manager();
4059 for subject_vid in subject_vids {
4060 let mut neighbors: Vec<(uni_common::core::id::Vid, uni_common::core::id::Eid)> =
4061 Vec::new();
4062 for dir in direction.store_directions() {
4063 neighbors.extend(adj.get_neighbors(
4064 uni_common::core::id::Vid::from(subject_vid),
4065 edge_type_id,
4066 *dir,
4067 ));
4068 }
4069 let mut values: Vec<f64> = Vec::with_capacity(neighbors.len());
4070 for (neighbor_vid, _eid) in neighbors {
4071 let val = property_manager
4072 .get_vertex_prop_with_ctx(neighbor_vid, &prop_name, query_ctx.as_ref())
4073 .await
4074 .map_err(|e| {
4075 datafusion::error::DataFusionError::Execution(format!(
4076 "neighbor-aggregator: failed to read property \
4077 '{prop_name}' on neighbor vid {neighbor_vid:?}: {e}"
4078 ))
4079 })?;
4080 if let Some(f) = val.as_f64()
4081 && !f.is_nan()
4082 {
4083 values.push(f);
4084 }
4085 }
4086 vid_to_values.insert(subject_vid, values);
4087 }
4088 out.insert((rel_type, prop_name, direction), Arc::new(vid_to_values));
4089 }
4090 Ok(out)
4091}
4092
4093fn build_path_context_lookup(
4099 handle: &crate::query::df_graph::locy_model_invoke::PathContextHandle,
4100 _subject_var: &str,
4101 column: &str,
4102 model_name: &str,
4103) -> DFResult<HashMap<u64, uni_locy::FeatureValue>> {
4104 if handle.schema.fields().is_empty() {
4109 return Err(datafusion::error::DataFusionError::Execution(format!(
4110 "model '{model_name}' path_context: source rule has empty yield schema"
4111 )));
4112 }
4113 let subj_idx = 0_usize;
4114 let col_idx = handle.schema.index_of(column).map_err(|_| {
4115 datafusion::error::DataFusionError::Execution(format!(
4116 "model '{model_name}' path_context: column '{column}' not in \
4117 source rule's yield schema (have: {:?})",
4118 handle
4119 .schema
4120 .fields()
4121 .iter()
4122 .map(|f| f.name().clone())
4123 .collect::<Vec<_>>()
4124 ))
4125 })?;
4126 let batches = handle.data.read();
4127 let mut out: HashMap<u64, uni_locy::FeatureValue> = HashMap::new();
4128 for batch in batches.iter() {
4129 let subj_col = batch.column(subj_idx);
4130 let value_col = batch.column(col_idx);
4131 for row in 0..batch.num_rows() {
4132 if subj_col.is_null(row) {
4133 continue;
4134 }
4135 let vid = if let Some(a) = subj_col.as_any().downcast_ref::<arrow_array::UInt64Array>()
4136 {
4137 a.value(row)
4138 } else if let Some(a) = subj_col.as_any().downcast_ref::<arrow_array::Int64Array>() {
4139 a.value(row) as u64
4140 } else {
4141 continue;
4142 };
4143 let v = extract_feature_value(value_col.as_ref(), row);
4144 out.insert(vid, v);
4147 }
4148 }
4149 Ok(out)
4150}
4151
4152fn extract_common_value(col: &dyn arrow_array::Array, row_idx: usize) -> uni_common::Value {
4157 use arrow_array::{BooleanArray, Float64Array, Int64Array, LargeStringArray, StringArray};
4158 if col.is_null(row_idx) {
4159 return uni_common::Value::Null;
4160 }
4161 if let Some(a) = col.as_any().downcast_ref::<Float64Array>() {
4162 return uni_common::Value::Float(a.value(row_idx));
4163 }
4164 if let Some(a) = col.as_any().downcast_ref::<Int64Array>() {
4165 return uni_common::Value::Int(a.value(row_idx));
4166 }
4167 if let Some(a) = col.as_any().downcast_ref::<BooleanArray>() {
4168 return uni_common::Value::Bool(a.value(row_idx));
4169 }
4170 if let Some(a) = col.as_any().downcast_ref::<StringArray>() {
4171 return uni_common::Value::String(a.value(row_idx).to_string());
4172 }
4173 if let Some(a) = col.as_any().downcast_ref::<LargeStringArray>() {
4174 return uni_common::Value::String(a.value(row_idx).to_string());
4175 }
4176 if let Some(b) = col.as_any().downcast_ref::<arrow_array::LargeBinaryArray>() {
4177 let bytes = b.value(row_idx);
4178 if bytes.is_empty() {
4179 return uni_common::Value::Null;
4180 }
4181 return uni_common::cypher_value_codec::decode(bytes).unwrap_or(uni_common::Value::Null);
4182 }
4183 uni_common::Value::Null
4184}
4185
4186fn extract_feature_value(col: &dyn arrow_array::Array, row_idx: usize) -> uni_locy::FeatureValue {
4187 use arrow_array::{BooleanArray, Float64Array, Int64Array, LargeStringArray, StringArray};
4188 if col.is_null(row_idx) {
4189 return uni_locy::FeatureValue::Null;
4190 }
4191 if let Some(a) = col.as_any().downcast_ref::<Float64Array>() {
4192 return uni_locy::FeatureValue::Float(a.value(row_idx));
4193 }
4194 if let Some(a) = col.as_any().downcast_ref::<Int64Array>() {
4195 return uni_locy::FeatureValue::Int(a.value(row_idx));
4196 }
4197 if let Some(a) = col.as_any().downcast_ref::<BooleanArray>() {
4198 return uni_locy::FeatureValue::Bool(a.value(row_idx));
4199 }
4200 if let Some(a) = col.as_any().downcast_ref::<StringArray>() {
4201 return uni_locy::FeatureValue::String(a.value(row_idx).to_string());
4202 }
4203 if let Some(a) = col.as_any().downcast_ref::<LargeStringArray>() {
4204 return uni_locy::FeatureValue::String(a.value(row_idx).to_string());
4205 }
4206 if let Some(b) = col.as_any().downcast_ref::<arrow_array::LargeBinaryArray>() {
4210 let bytes = b.value(row_idx);
4211 if bytes.is_empty() {
4212 return uni_locy::FeatureValue::Null;
4213 }
4214 let v = uni_common::cypher_value_codec::decode(bytes).unwrap_or(uni_common::Value::Null);
4215 return match v {
4216 uni_common::Value::Float(f) => uni_locy::FeatureValue::Float(f),
4217 uni_common::Value::Int(i) => uni_locy::FeatureValue::Int(i),
4218 uni_common::Value::Bool(b) => uni_locy::FeatureValue::Bool(b),
4219 uni_common::Value::String(s) => uni_locy::FeatureValue::String(s),
4220 uni_common::Value::Null => uni_locy::FeatureValue::Null,
4221 _ => uni_locy::FeatureValue::Null,
4222 };
4223 }
4224 uni_locy::FeatureValue::Null
4225}
4226
4227pub fn apply_prob_complement(
4234 batches: Vec<RecordBatch>,
4235 neg_facts: &[RecordBatch],
4236 left_col: &str,
4237 right_col: &str,
4238 prob_col: &str,
4239 complement_col_name: &str,
4240) -> datafusion::error::Result<Vec<RecordBatch>> {
4241 use arrow_array::{Array as _, Float64Array, UInt64Array};
4242
4243 let mut prob_map: std::collections::HashMap<u64, f64> = std::collections::HashMap::new();
4245 for batch in neg_facts {
4246 let Ok(vid_idx) = batch.schema().index_of(right_col) else {
4247 continue;
4248 };
4249 let Ok(prob_idx) = batch.schema().index_of(prob_col) else {
4250 continue;
4251 };
4252 let Some(vids) = batch.column(vid_idx).as_any().downcast_ref::<UInt64Array>() else {
4253 continue;
4254 };
4255 let prob_arr = batch.column(prob_idx);
4256 let probs = prob_arr.as_any().downcast_ref::<Float64Array>();
4257 for i in 0..vids.len() {
4258 if !vids.is_null(i) {
4259 let p = probs
4260 .and_then(|arr| {
4261 if arr.is_null(i) {
4262 None
4263 } else {
4264 Some(arr.value(i))
4265 }
4266 })
4267 .unwrap_or(0.0);
4268 prob_map
4271 .entry(vids.value(i))
4272 .and_modify(|existing| {
4273 *existing = 1.0 - (1.0 - *existing) * (1.0 - p);
4274 })
4275 .or_insert(p);
4276 }
4277 }
4278 }
4279
4280 let mut result = Vec::new();
4282 for batch in batches {
4283 let Ok(idx) = batch.schema().index_of(left_col) else {
4284 result.push(batch);
4285 continue;
4286 };
4287 let Some(vids) = batch.column(idx).as_any().downcast_ref::<UInt64Array>() else {
4288 result.push(batch);
4289 continue;
4290 };
4291
4292 let complements: Vec<f64> = (0..vids.len())
4294 .map(|i| {
4295 if vids.is_null(i) {
4296 1.0
4297 } else {
4298 let p = prob_map.get(&vids.value(i)).copied().unwrap_or(0.0);
4299 1.0 - p
4300 }
4301 })
4302 .collect();
4303
4304 let complement_arr = Float64Array::from(complements);
4305
4306 let mut columns: Vec<arrow_array::ArrayRef> = batch.columns().to_vec();
4308 columns.push(std::sync::Arc::new(complement_arr));
4309
4310 let mut fields: Vec<std::sync::Arc<arrow_schema::Field>> =
4311 batch.schema().fields().iter().cloned().collect();
4312 fields.push(std::sync::Arc::new(arrow_schema::Field::new(
4313 complement_col_name,
4314 arrow_schema::DataType::Float64,
4315 true,
4316 )));
4317
4318 let new_schema = std::sync::Arc::new(arrow_schema::Schema::new(fields));
4319 let new_batch = RecordBatch::try_new(new_schema, columns).map_err(arrow_err)?;
4320 result.push(new_batch);
4321 }
4322 Ok(result)
4323}
4324
4325pub fn apply_prob_complement_composite(
4332 batches: Vec<RecordBatch>,
4333 neg_facts: &[RecordBatch],
4334 join_cols: &[(String, String)],
4335 prob_col: &str,
4336 complement_col_name: &str,
4337) -> datafusion::error::Result<Vec<RecordBatch>> {
4338 use arrow_array::{Array as _, Float64Array, UInt64Array};
4339
4340 let mut prob_map: HashMap<Vec<u64>, f64> = HashMap::new();
4342 for batch in neg_facts {
4343 let right_indices: Vec<usize> = join_cols
4344 .iter()
4345 .filter_map(|(_, rc)| batch.schema().index_of(rc).ok())
4346 .collect();
4347 if right_indices.len() != join_cols.len() {
4348 continue;
4349 }
4350 let Ok(prob_idx) = batch.schema().index_of(prob_col) else {
4351 continue;
4352 };
4353 let prob_arr = batch.column(prob_idx);
4354 let probs = prob_arr.as_any().downcast_ref::<Float64Array>();
4355 for row in 0..batch.num_rows() {
4356 let mut key = Vec::with_capacity(right_indices.len());
4357 let mut valid = true;
4358 for &ci in &right_indices {
4359 let col = batch.column(ci);
4360 if let Some(vids) = col.as_any().downcast_ref::<UInt64Array>() {
4361 if vids.is_null(row) {
4362 valid = false;
4363 break;
4364 }
4365 key.push(vids.value(row));
4366 } else {
4367 valid = false;
4368 break;
4369 }
4370 }
4371 if !valid {
4372 continue;
4373 }
4374 let p = probs
4375 .and_then(|arr| {
4376 if arr.is_null(row) {
4377 None
4378 } else {
4379 Some(arr.value(row))
4380 }
4381 })
4382 .unwrap_or(0.0);
4383 prob_map
4385 .entry(key)
4386 .and_modify(|existing| {
4387 *existing = 1.0 - (1.0 - *existing) * (1.0 - p);
4388 })
4389 .or_insert(p);
4390 }
4391 }
4392
4393 let mut result = Vec::new();
4395 for batch in batches {
4396 let left_indices: Vec<usize> = join_cols
4397 .iter()
4398 .filter_map(|(lc, _)| batch.schema().index_of(lc).ok())
4399 .collect();
4400 if left_indices.len() != join_cols.len() {
4401 result.push(batch);
4402 continue;
4403 }
4404 let all_u64 = left_indices.iter().all(|&ci| {
4405 batch
4406 .column(ci)
4407 .as_any()
4408 .downcast_ref::<UInt64Array>()
4409 .is_some()
4410 });
4411 if !all_u64 {
4412 result.push(batch);
4413 continue;
4414 }
4415
4416 let complements: Vec<f64> = (0..batch.num_rows())
4417 .map(|row| {
4418 let mut key = Vec::with_capacity(left_indices.len());
4419 for &ci in &left_indices {
4420 let vids = batch
4421 .column(ci)
4422 .as_any()
4423 .downcast_ref::<UInt64Array>()
4424 .unwrap();
4425 if vids.is_null(row) {
4426 return 1.0;
4427 }
4428 key.push(vids.value(row));
4429 }
4430 let p = prob_map.get(&key).copied().unwrap_or(0.0);
4431 1.0 - p
4432 })
4433 .collect();
4434
4435 let complement_arr = Float64Array::from(complements);
4436 let mut columns: Vec<arrow_array::ArrayRef> = batch.columns().to_vec();
4437 columns.push(Arc::new(complement_arr));
4438
4439 let mut fields: Vec<Arc<arrow_schema::Field>> =
4440 batch.schema().fields().iter().cloned().collect();
4441 fields.push(Arc::new(arrow_schema::Field::new(
4442 complement_col_name,
4443 arrow_schema::DataType::Float64,
4444 true,
4445 )));
4446
4447 let new_schema = Arc::new(arrow_schema::Schema::new(fields));
4448 let new_batch = RecordBatch::try_new(new_schema, columns).map_err(arrow_err)?;
4449 result.push(new_batch);
4450 }
4451 Ok(result)
4452}
4453
4454pub fn apply_anti_join_composite(
4460 batches: Vec<RecordBatch>,
4461 neg_facts: &[RecordBatch],
4462 join_cols: &[(String, String)],
4463) -> datafusion::error::Result<Vec<RecordBatch>> {
4464 use arrow::compute::filter_record_batch;
4465 use arrow_array::{Array as _, BooleanArray, UInt64Array};
4466
4467 let mut banned: HashSet<Vec<u64>> = HashSet::new();
4469 for batch in neg_facts {
4470 let right_indices: Vec<usize> = join_cols
4471 .iter()
4472 .filter_map(|(_, rc)| batch.schema().index_of(rc).ok())
4473 .collect();
4474 if right_indices.len() != join_cols.len() {
4475 continue;
4476 }
4477 for row in 0..batch.num_rows() {
4478 let mut key = Vec::with_capacity(right_indices.len());
4479 let mut valid = true;
4480 for &ci in &right_indices {
4481 let col = batch.column(ci);
4482 if let Some(vids) = col.as_any().downcast_ref::<UInt64Array>() {
4483 if vids.is_null(row) {
4484 valid = false;
4485 break;
4486 }
4487 key.push(vids.value(row));
4488 } else {
4489 valid = false;
4490 break;
4491 }
4492 }
4493 if valid {
4494 banned.insert(key);
4495 }
4496 }
4497 }
4498
4499 if banned.is_empty() {
4500 return Ok(batches);
4501 }
4502
4503 let mut result = Vec::new();
4505 for batch in batches {
4506 let left_indices: Vec<usize> = join_cols
4507 .iter()
4508 .filter_map(|(lc, _)| batch.schema().index_of(lc).ok())
4509 .collect();
4510 if left_indices.len() != join_cols.len() {
4511 result.push(batch);
4512 continue;
4513 }
4514 let all_u64 = left_indices.iter().all(|&ci| {
4515 batch
4516 .column(ci)
4517 .as_any()
4518 .downcast_ref::<UInt64Array>()
4519 .is_some()
4520 });
4521 if !all_u64 {
4522 result.push(batch);
4523 continue;
4524 }
4525
4526 let keep: Vec<bool> = (0..batch.num_rows())
4527 .map(|row| {
4528 let mut key = Vec::with_capacity(left_indices.len());
4529 for &ci in &left_indices {
4530 let vids = batch
4531 .column(ci)
4532 .as_any()
4533 .downcast_ref::<UInt64Array>()
4534 .unwrap();
4535 if vids.is_null(row) {
4536 return true; }
4538 key.push(vids.value(row));
4539 }
4540 !banned.contains(&key)
4541 })
4542 .collect();
4543 let keep_arr = BooleanArray::from(keep);
4544 let filtered = filter_record_batch(&batch, &keep_arr).map_err(arrow_err)?;
4545 if filtered.num_rows() > 0 {
4546 result.push(filtered);
4547 }
4548 }
4549 Ok(result)
4550}
4551
4552pub fn multiply_prob_factors(
4563 batches: Vec<RecordBatch>,
4564 prob_col: Option<&str>,
4565 complement_cols: &[String],
4566) -> datafusion::error::Result<Vec<RecordBatch>> {
4567 use arrow_array::{Array as _, Float64Array};
4568
4569 let mut result = Vec::with_capacity(batches.len());
4570
4571 for batch in batches {
4572 if batch.num_rows() == 0 {
4573 let keep: Vec<usize> = batch
4575 .schema()
4576 .fields()
4577 .iter()
4578 .enumerate()
4579 .filter(|(_, f)| !complement_cols.contains(f.name()))
4580 .map(|(i, _)| i)
4581 .collect();
4582 let fields: Vec<_> = keep
4583 .iter()
4584 .map(|&i| batch.schema().field(i).clone())
4585 .collect();
4586 let cols: Vec<_> = keep.iter().map(|&i| batch.column(i).clone()).collect();
4587 let schema = std::sync::Arc::new(arrow_schema::Schema::new(fields));
4588 result.push(
4589 RecordBatch::try_new(schema, cols).map_err(|e| {
4590 datafusion::error::DataFusionError::ArrowError(Box::new(e), None)
4591 })?,
4592 );
4593 continue;
4594 }
4595
4596 let num_rows = batch.num_rows();
4597
4598 let mut combined = vec![1.0f64; num_rows];
4600 for col_name in complement_cols {
4601 if let Ok(idx) = batch.schema().index_of(col_name) {
4602 let arr = batch
4603 .column(idx)
4604 .as_any()
4605 .downcast_ref::<Float64Array>()
4606 .ok_or_else(|| {
4607 datafusion::error::DataFusionError::Internal(format!(
4608 "Expected Float64 for complement column {col_name}"
4609 ))
4610 })?;
4611 for (i, val) in combined.iter_mut().enumerate().take(num_rows) {
4612 if !arr.is_null(i) {
4613 *val *= arr.value(i);
4614 }
4615 }
4616 }
4617 }
4618
4619 let final_prob: Vec<f64> = if let Some(prob_name) = prob_col {
4621 if let Ok(idx) = batch.schema().index_of(prob_name) {
4622 let arr = batch
4623 .column(idx)
4624 .as_any()
4625 .downcast_ref::<Float64Array>()
4626 .ok_or_else(|| {
4627 datafusion::error::DataFusionError::Internal(format!(
4628 "Expected Float64 for PROB column {prob_name}"
4629 ))
4630 })?;
4631 (0..num_rows)
4632 .map(|i| {
4633 if arr.is_null(i) {
4634 combined[i]
4635 } else {
4636 arr.value(i) * combined[i]
4637 }
4638 })
4639 .collect()
4640 } else {
4641 combined
4642 }
4643 } else {
4644 combined
4645 };
4646
4647 let new_prob_array: arrow_array::ArrayRef =
4648 std::sync::Arc::new(Float64Array::from(final_prob));
4649
4650 let mut fields = Vec::new();
4652 let mut columns = Vec::new();
4653
4654 for (idx, field) in batch.schema().fields().iter().enumerate() {
4655 if complement_cols.contains(field.name()) {
4656 continue;
4657 }
4658 if prob_col.is_some_and(|p| field.name() == p) {
4659 fields.push(field.clone());
4660 columns.push(new_prob_array.clone());
4661 } else {
4662 fields.push(field.clone());
4663 columns.push(batch.column(idx).clone());
4664 }
4665 }
4666
4667 let schema = std::sync::Arc::new(arrow_schema::Schema::new(fields));
4668 result.push(RecordBatch::try_new(schema, columns).map_err(arrow_err)?);
4669 }
4670
4671 Ok(result)
4672}
4673
4674fn update_derived_scan_handles(
4679 registry: &DerivedScanRegistry,
4680 states: &[FixpointState],
4681 current_rule_idx: usize,
4682 rules: &[FixpointRulePlan],
4683) {
4684 let current_rule_name = &rules[current_rule_idx].name;
4685
4686 for entry in ®istry.entries {
4687 let source_state_idx = rules.iter().position(|r| r.name == entry.rule_name);
4689 let Some(source_idx) = source_state_idx else {
4690 continue;
4691 };
4692
4693 let is_self = entry.rule_name == *current_rule_name;
4694 let data = if is_self && !rules[current_rule_idx].non_linear {
4695 states[source_idx].all_delta().to_vec()
4697 } else {
4698 states[source_idx].all_facts().to_vec()
4701 };
4702
4703 let data = if data.is_empty() || data.iter().all(|b| b.num_rows() == 0) {
4705 vec![RecordBatch::new_empty(Arc::clone(&entry.schema))]
4706 } else {
4707 data
4708 };
4709
4710 let mut guard = entry.data.write();
4711 *guard = data;
4712 }
4713}
4714
4715pub struct DerivedScanExec {
4725 data: Arc<RwLock<Vec<RecordBatch>>>,
4726 schema: SchemaRef,
4727 properties: Arc<PlanProperties>,
4728}
4729
4730impl DerivedScanExec {
4731 pub fn new(data: Arc<RwLock<Vec<RecordBatch>>>, schema: SchemaRef) -> Self {
4732 let properties = compute_plan_properties(Arc::clone(&schema));
4733 Self {
4734 data,
4735 schema,
4736 properties,
4737 }
4738 }
4739}
4740
4741impl fmt::Debug for DerivedScanExec {
4742 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
4743 f.debug_struct("DerivedScanExec")
4744 .field("schema", &self.schema)
4745 .finish()
4746 }
4747}
4748
4749impl DisplayAs for DerivedScanExec {
4750 fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter<'_>) -> fmt::Result {
4751 write!(f, "DerivedScanExec")
4752 }
4753}
4754
4755impl ExecutionPlan for DerivedScanExec {
4756 fn name(&self) -> &str {
4757 "DerivedScanExec"
4758 }
4759 fn as_any(&self) -> &dyn Any {
4760 self
4761 }
4762 fn schema(&self) -> SchemaRef {
4763 Arc::clone(&self.schema)
4764 }
4765 fn properties(&self) -> &Arc<PlanProperties> {
4766 &self.properties
4767 }
4768 fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
4769 vec![]
4770 }
4771 fn with_new_children(
4772 self: Arc<Self>,
4773 _children: Vec<Arc<dyn ExecutionPlan>>,
4774 ) -> DFResult<Arc<dyn ExecutionPlan>> {
4775 Ok(self)
4776 }
4777 fn execute(
4778 &self,
4779 _partition: usize,
4780 _context: Arc<TaskContext>,
4781 ) -> DFResult<SendableRecordBatchStream> {
4782 let batches = {
4783 let guard = self.data.read();
4784 if guard.is_empty() {
4785 vec![RecordBatch::new_empty(Arc::clone(&self.schema))]
4786 } else {
4787 guard
4793 .iter()
4794 .map(|b| {
4795 RecordBatch::try_new(Arc::clone(&self.schema), b.columns().to_vec())
4796 .map_err(|e| {
4797 datafusion::error::DataFusionError::ArrowError(Box::new(e), None)
4798 })
4799 })
4800 .collect::<DFResult<Vec<_>>>()?
4801 }
4802 };
4803 Ok(Box::pin(MemoryStream::try_new(
4804 batches,
4805 Arc::clone(&self.schema),
4806 None,
4807 )?))
4808 }
4809}
4810
4811struct InMemoryExec {
4820 batches: Vec<RecordBatch>,
4821 schema: SchemaRef,
4822 properties: Arc<PlanProperties>,
4823}
4824
4825impl InMemoryExec {
4826 fn new(batches: Vec<RecordBatch>, schema: SchemaRef) -> Self {
4827 let properties = compute_plan_properties(Arc::clone(&schema));
4828 Self {
4829 batches,
4830 schema,
4831 properties,
4832 }
4833 }
4834}
4835
4836impl fmt::Debug for InMemoryExec {
4837 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
4838 f.debug_struct("InMemoryExec")
4839 .field("num_batches", &self.batches.len())
4840 .field("schema", &self.schema)
4841 .finish()
4842 }
4843}
4844
4845impl DisplayAs for InMemoryExec {
4846 fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter<'_>) -> fmt::Result {
4847 write!(f, "InMemoryExec: batches={}", self.batches.len())
4848 }
4849}
4850
4851impl ExecutionPlan for InMemoryExec {
4852 fn name(&self) -> &str {
4853 "InMemoryExec"
4854 }
4855 fn as_any(&self) -> &dyn Any {
4856 self
4857 }
4858 fn schema(&self) -> SchemaRef {
4859 Arc::clone(&self.schema)
4860 }
4861 fn properties(&self) -> &Arc<PlanProperties> {
4862 &self.properties
4863 }
4864 fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
4865 vec![]
4866 }
4867 fn with_new_children(
4868 self: Arc<Self>,
4869 _children: Vec<Arc<dyn ExecutionPlan>>,
4870 ) -> DFResult<Arc<dyn ExecutionPlan>> {
4871 Ok(self)
4872 }
4873 fn execute(
4874 &self,
4875 _partition: usize,
4876 _context: Arc<TaskContext>,
4877 ) -> DFResult<SendableRecordBatchStream> {
4878 Ok(Box::pin(MemoryStream::try_new(
4879 self.batches.clone(),
4880 Arc::clone(&self.schema),
4881 None,
4882 )?))
4883 }
4884}
4885
4886fn apply_having_filter(
4896 batches: Vec<RecordBatch>,
4897 having_exprs: &[Expr],
4898 schema: &SchemaRef,
4899 task_ctx: &Arc<TaskContext>,
4900) -> DFResult<Vec<RecordBatch>> {
4901 use arrow::compute::{and, filter_record_batch};
4902 use arrow_array::BooleanArray;
4903 use datafusion::common::DFSchema;
4904 use datafusion::logical_expr::LogicalPlanBuilder;
4905 use datafusion::logical_expr::execution_props::ExecutionProps;
4906 use datafusion::optimizer::AnalyzerRule;
4907 use datafusion::optimizer::analyzer::type_coercion::TypeCoercion;
4908 use datafusion::physical_expr::create_physical_expr;
4909
4910 if batches.is_empty() {
4911 return Ok(batches);
4912 }
4913
4914 let df_schema = DFSchema::try_from(schema.as_ref().clone()).map_err(|e| {
4916 datafusion::common::DataFusionError::Internal(format!("HAVING schema conversion: {e}"))
4917 })?;
4918
4919 let config = (**task_ctx.session_config().options()).clone();
4924 let props = ExecutionProps::new();
4925
4926 let physical_exprs: Vec<Arc<dyn datafusion::physical_expr::PhysicalExpr>> = having_exprs
4932 .iter()
4933 .map(|expr| {
4934 let df_expr = crate::query::df_expr::cypher_expr_to_df(expr, None).map_err(|e| {
4935 datafusion::common::DataFusionError::Internal(format!(
4936 "HAVING expression conversion: {e}"
4937 ))
4938 })?;
4939
4940 let empty = datafusion::logical_expr::LogicalPlan::EmptyRelation(
4944 datafusion::logical_expr::EmptyRelation {
4945 produce_one_row: false,
4946 schema: Arc::new(df_schema.clone()),
4947 },
4948 );
4949 let filter_plan = LogicalPlanBuilder::from(empty)
4950 .filter(df_expr.clone())?
4951 .build()?;
4952 let coerced_expr = match TypeCoercion::new().analyze(filter_plan, &config) {
4953 Ok(datafusion::logical_expr::LogicalPlan::Filter(f)) => f.predicate,
4954 _ => df_expr,
4955 };
4956
4957 create_physical_expr(&coerced_expr, &df_schema, &props)
4958 })
4959 .collect::<DFResult<Vec<_>>>()?;
4960
4961 let mut result = Vec::new();
4962 for batch in batches {
4963 let mut mask: Option<BooleanArray> = None;
4965 for phys_expr in &physical_exprs {
4966 let value = phys_expr.evaluate(&batch)?;
4967 let arr = value.into_array(batch.num_rows())?;
4968 let bool_arr = arr.as_any().downcast_ref::<BooleanArray>().ok_or_else(|| {
4969 datafusion::common::DataFusionError::Internal(
4970 "HAVING condition must evaluate to boolean".into(),
4971 )
4972 })?;
4973 mask = Some(match mask {
4974 None => bool_arr.clone(),
4975 Some(prev) => and(&prev, bool_arr).map_err(arrow_err)?,
4976 });
4977 }
4978 if let Some(ref m) = mask {
4979 let filtered = filter_record_batch(&batch, m).map_err(arrow_err)?;
4980 if filtered.num_rows() > 0 {
4981 result.push(filtered);
4982 }
4983 } else {
4984 result.push(batch);
4985 }
4986 }
4987 Ok(result)
4988}
4989
4990#[allow(
4992 clippy::too_many_arguments,
4993 reason = "context bundle would be over-engineering for one call site"
4994)]
4995pub(crate) async fn apply_post_fixpoint_chain(
4996 facts: Vec<RecordBatch>,
4997 rule: &FixpointRulePlan,
4998 task_ctx: &Arc<TaskContext>,
4999 strict_probability_domain: bool,
5000 probability_epsilon: f64,
5001 semiring_kind: SemiringKind,
5002 provenance_tracker: Option<Arc<ProvenanceStore>>,
5003 top_k_proofs_k: usize,
5004 registry: Option<Arc<DerivedScanRegistry>>,
5005) -> DFResult<Vec<RecordBatch>> {
5006 if !rule.has_fold && !rule.has_best_by && !rule.has_priority && rule.having.is_empty() {
5007 return Ok(facts);
5008 }
5009
5010 let schema = facts
5015 .iter()
5016 .find(|b| b.num_rows() > 0)
5017 .map(|b| b.schema())
5018 .unwrap_or_else(|| Arc::clone(&rule.yield_schema));
5019
5020 let topk_k: Option<usize> = match semiring_kind {
5034 SemiringKind::TopKProofs { k } if k > 0 => Some(k as usize),
5035 _ => None,
5036 };
5037 let body_support_map: Option<Arc<HashMap<Vec<u8>, Vec<ProofTerm>>>> = if topk_k.is_some()
5038 && !rule.has_priority
5039 && let Some(registry) = registry.as_ref()
5040 {
5041 let mut map: HashMap<Vec<u8>, Vec<ProofTerm>> = HashMap::new();
5042 for batch in &facts {
5043 let all_indices: Vec<usize> = (0..batch.num_columns()).collect();
5044 for row_idx in 0..batch.num_rows() {
5045 let support = collect_is_ref_inputs_for_body_row(rule, batch, row_idx, registry);
5046 if support.is_empty() {
5047 continue;
5048 }
5049 let hash = fact_hash_key(batch, &all_indices, row_idx);
5050 map.insert(hash, support);
5051 }
5052 }
5053 if map.is_empty() {
5054 None
5055 } else {
5056 Some(Arc::new(map))
5057 }
5058 } else {
5059 None
5060 };
5061
5062 let input: Arc<dyn ExecutionPlan> = Arc::new(InMemoryExec::new(facts, schema.clone()));
5063
5064 let key_column_indices: Vec<usize> = rule
5069 .key_column_indices
5070 .iter()
5071 .filter_map(|&i| {
5072 let name = rule.yield_schema.field(i).name();
5073 schema.index_of(name).ok()
5074 })
5075 .collect();
5076
5077 let current: Arc<dyn ExecutionPlan> = if rule.has_priority {
5081 let priority_schema = input.schema();
5082 let priority_idx = priority_schema.index_of("__priority").map_err(|_| {
5083 datafusion::common::DataFusionError::Internal(
5084 "PRIORITY rule missing __priority column".to_string(),
5085 )
5086 })?;
5087 Arc::new(PriorityExec::new(
5088 input,
5089 key_column_indices.clone(),
5090 priority_idx,
5091 ))
5092 } else {
5093 input
5094 };
5095
5096 let current: Arc<dyn ExecutionPlan> = if rule.has_fold && !rule.fold_bindings.is_empty() {
5098 Arc::new(FoldExec::new_with_topk(
5099 current,
5100 key_column_indices.clone(),
5101 rule.fold_bindings.clone(),
5102 strict_probability_domain,
5103 probability_epsilon,
5104 semiring_kind,
5105 provenance_tracker.clone(),
5106 topk_k.unwrap_or(top_k_proofs_k),
5107 body_support_map.clone(),
5108 ))
5109 } else {
5110 current
5111 };
5112
5113 let current: Arc<dyn ExecutionPlan> = if !rule.having.is_empty() {
5115 let batches = collect_all_partitions(¤t, Arc::clone(task_ctx)).await?;
5116 let filtered = apply_having_filter(batches, &rule.having, ¤t.schema(), task_ctx)?;
5117 if filtered.is_empty() {
5118 return Ok(filtered);
5119 }
5120 Arc::new(InMemoryExec::new(filtered, Arc::clone(¤t.schema())))
5121 } else {
5122 current
5123 };
5124
5125 let current: Arc<dyn ExecutionPlan> = if rule.has_best_by && !rule.best_by_criteria.is_empty() {
5127 Arc::new(BestByExec::new(
5128 current,
5129 key_column_indices.clone(),
5130 rule.best_by_criteria.clone(),
5131 rule.deterministic,
5132 ))
5133 } else {
5134 current
5135 };
5136
5137 collect_all_partitions(¤t, Arc::clone(task_ctx)).await
5138}
5139
/// Execution-plan node holding the rule plans and the runtime configuration for
/// a fixpoint evaluation.
///
/// NOTE(review): field roles below are read off names and types only; the
/// code that consumes them is outside this excerpt, so confirm before relying
/// on them.
pub struct FixpointExec {
    /// Per-rule plans evaluated by this node.
    rules: Vec<FixpointRulePlan>,
    /// Presumably the iteration cap for the fixpoint loop — confirm.
    max_iterations: usize,
    /// Presumably the wall-clock budget for evaluation — confirm.
    timeout: Duration,
    graph_ctx: Arc<GraphExecutionContext>,
    session_ctx: Arc<RwLock<datafusion::prelude::SessionContext>>,
    storage: Arc<StorageManager>,
    schema_info: Arc<UniSchema>,
    /// Named parameter values.
    params: HashMap<String, Value>,
    /// Registry of derived-scan handles shared with the rule plans.
    derived_scan_registry: Arc<DerivedScanRegistry>,
    /// Schema reported as this node's output.
    output_schema: SchemaRef,
    properties: Arc<PlanProperties>,
    metrics: ExecutionPlanMetricsSet,
    /// Presumably a byte budget for derived facts — confirm.
    max_derived_bytes: usize,
    derivation_tracker: Option<Arc<ProvenanceStore>>,
    /// Shared map from a string key (presumably a rule name) to a count.
    iteration_counts: Arc<StdRwLock<HashMap<String, usize>>>,
    strict_probability_domain: bool,
    probability_epsilon: f64,
    exact_probability: bool,
    max_bdd_variables: usize,
    /// Shared slot of runtime warnings.
    warnings_slot: Arc<StdRwLock<Vec<RuntimeWarning>>>,
    approximate_slot: Arc<StdRwLock<HashMap<String, Vec<String>>>>,
    top_k_proofs: usize,
    timeout_flag: Arc<std::sync::atomic::AtomicU8>,
    semiring_kind: SemiringKind,
    classifier_registry: Arc<ClassifierRegistry>,
    classifier_cache: Option<Arc<ModelInvocationCache>>,
    #[allow(
        dead_code,
        reason = "boundary plumbing; read by EXPLAIN via LocyModelInvokeExec"
    )]
    classifier_provenance_store: Option<Arc<uni_locy::NeuralProvenanceStore>>,
    profile_collector: Option<Arc<LocyProfileCollector>>,
}
5206
5207impl fmt::Debug for FixpointExec {
5208 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
5209 f.debug_struct("FixpointExec")
5210 .field("rules_count", &self.rules.len())
5211 .field("max_iterations", &self.max_iterations)
5212 .field("timeout", &self.timeout)
5213 .field("output_schema", &self.output_schema)
5214 .field("max_derived_bytes", &self.max_derived_bytes)
5215 .finish_non_exhaustive()
5216 }
5217}
5218
/// Constructors and accessors for [`FixpointExec`].
///
/// The constructors form a chain: each shorter one fills in defaults and
/// delegates to the next, ending in `new_with_semiring_classifiers_and_cache`,
/// which actually builds the struct.
impl FixpointExec {
    /// Legacy constructor.
    ///
    /// Delegates to `new_with_semiring_and_classifiers` with
    /// `SemiringKind::AddMultProb` and a newly constructed `ClassifierRegistry`.
    #[expect(
        clippy::too_many_arguments,
        reason = "FixpointExec configuration needs all context"
    )]
    #[deprecated(
        note = "use `new_with_semiring_classifiers_and_cache` (or the lighter \
                `new_with_semiring_and_classifiers` / `new_with_semiring`) — \
                this legacy ctor defaults the semiring to AddMultProb and \
                ships no classifier registry, which the Phase B+ runtime needs \
                explicitly. To be removed after C0 Stage 2."
    )]
    pub fn new(
        rules: Vec<FixpointRulePlan>,
        max_iterations: usize,
        timeout: Duration,
        graph_ctx: Arc<GraphExecutionContext>,
        session_ctx: Arc<RwLock<datafusion::prelude::SessionContext>>,
        storage: Arc<StorageManager>,
        schema_info: Arc<UniSchema>,
        params: HashMap<String, Value>,
        derived_scan_registry: Arc<DerivedScanRegistry>,
        output_schema: SchemaRef,
        max_derived_bytes: usize,
        derivation_tracker: Option<Arc<ProvenanceStore>>,
        iteration_counts: Arc<StdRwLock<HashMap<String, usize>>>,
        strict_probability_domain: bool,
        probability_epsilon: f64,
        exact_probability: bool,
        max_bdd_variables: usize,
        warnings_slot: Arc<StdRwLock<Vec<RuntimeWarning>>>,
        approximate_slot: Arc<StdRwLock<HashMap<String, Vec<String>>>>,
        top_k_proofs: usize,
        timeout_flag: Arc<std::sync::atomic::AtomicU8>,
    ) -> Self {
        Self::new_with_semiring_and_classifiers(
            rules,
            max_iterations,
            timeout,
            graph_ctx,
            session_ctx,
            storage,
            schema_info,
            params,
            derived_scan_registry,
            output_schema,
            max_derived_bytes,
            derivation_tracker,
            iteration_counts,
            strict_probability_domain,
            probability_epsilon,
            exact_probability,
            max_bdd_variables,
            warnings_slot,
            approximate_slot,
            top_k_proofs,
            timeout_flag,
            SemiringKind::AddMultProb,
            Arc::new(ClassifierRegistry::new()),
        )
    }

    /// Constructor with an explicit semiring.
    ///
    /// Delegates to `new_with_semiring_and_classifiers`, supplying a newly
    /// constructed `ClassifierRegistry`.
    #[expect(
        clippy::too_many_arguments,
        reason = "FixpointExec configuration needs all context"
    )]
    pub fn new_with_semiring(
        rules: Vec<FixpointRulePlan>,
        max_iterations: usize,
        timeout: Duration,
        graph_ctx: Arc<GraphExecutionContext>,
        session_ctx: Arc<RwLock<datafusion::prelude::SessionContext>>,
        storage: Arc<StorageManager>,
        schema_info: Arc<UniSchema>,
        params: HashMap<String, Value>,
        derived_scan_registry: Arc<DerivedScanRegistry>,
        output_schema: SchemaRef,
        max_derived_bytes: usize,
        derivation_tracker: Option<Arc<ProvenanceStore>>,
        iteration_counts: Arc<StdRwLock<HashMap<String, usize>>>,
        strict_probability_domain: bool,
        probability_epsilon: f64,
        exact_probability: bool,
        max_bdd_variables: usize,
        warnings_slot: Arc<StdRwLock<Vec<RuntimeWarning>>>,
        approximate_slot: Arc<StdRwLock<HashMap<String, Vec<String>>>>,
        top_k_proofs: usize,
        timeout_flag: Arc<std::sync::atomic::AtomicU8>,
        semiring_kind: SemiringKind,
    ) -> Self {
        Self::new_with_semiring_and_classifiers(
            rules,
            max_iterations,
            timeout,
            graph_ctx,
            session_ctx,
            storage,
            schema_info,
            params,
            derived_scan_registry,
            output_schema,
            max_derived_bytes,
            derivation_tracker,
            iteration_counts,
            strict_probability_domain,
            probability_epsilon,
            exact_probability,
            max_bdd_variables,
            warnings_slot,
            approximate_slot,
            top_k_proofs,
            timeout_flag,
            semiring_kind,
            Arc::new(ClassifierRegistry::new()),
        )
    }

    /// Constructor with an explicit semiring and classifier registry.
    ///
    /// Delegates to `new_with_semiring_classifiers_and_cache` with no classifier
    /// cache and no classifier provenance store (both `None`).
    #[expect(
        clippy::too_many_arguments,
        reason = "FixpointExec configuration needs all context"
    )]
    pub fn new_with_semiring_and_classifiers(
        rules: Vec<FixpointRulePlan>,
        max_iterations: usize,
        timeout: Duration,
        graph_ctx: Arc<GraphExecutionContext>,
        session_ctx: Arc<RwLock<datafusion::prelude::SessionContext>>,
        storage: Arc<StorageManager>,
        schema_info: Arc<UniSchema>,
        params: HashMap<String, Value>,
        derived_scan_registry: Arc<DerivedScanRegistry>,
        output_schema: SchemaRef,
        max_derived_bytes: usize,
        derivation_tracker: Option<Arc<ProvenanceStore>>,
        iteration_counts: Arc<StdRwLock<HashMap<String, usize>>>,
        strict_probability_domain: bool,
        probability_epsilon: f64,
        exact_probability: bool,
        max_bdd_variables: usize,
        warnings_slot: Arc<StdRwLock<Vec<RuntimeWarning>>>,
        approximate_slot: Arc<StdRwLock<HashMap<String, Vec<String>>>>,
        top_k_proofs: usize,
        timeout_flag: Arc<std::sync::atomic::AtomicU8>,
        semiring_kind: SemiringKind,
        classifier_registry: Arc<ClassifierRegistry>,
    ) -> Self {
        Self::new_with_semiring_classifiers_and_cache(
            rules,
            max_iterations,
            timeout,
            graph_ctx,
            session_ctx,
            storage,
            schema_info,
            params,
            derived_scan_registry,
            output_schema,
            max_derived_bytes,
            derivation_tracker,
            iteration_counts,
            strict_probability_domain,
            probability_epsilon,
            exact_probability,
            max_bdd_variables,
            warnings_slot,
            approximate_slot,
            top_k_proofs,
            timeout_flag,
            semiring_kind,
            classifier_registry,
            None,
            None,
        )
    }

    /// Full constructor; every other constructor ends up here.
    ///
    /// Besides storing the arguments it derives the plan properties from
    /// `output_schema`, creates a fresh metrics set, and leaves the profile
    /// collector unset.
    #[expect(
        clippy::too_many_arguments,
        reason = "FixpointExec configuration needs all context"
    )]
    pub fn new_with_semiring_classifiers_and_cache(
        rules: Vec<FixpointRulePlan>,
        max_iterations: usize,
        timeout: Duration,
        graph_ctx: Arc<GraphExecutionContext>,
        session_ctx: Arc<RwLock<datafusion::prelude::SessionContext>>,
        storage: Arc<StorageManager>,
        schema_info: Arc<UniSchema>,
        params: HashMap<String, Value>,
        derived_scan_registry: Arc<DerivedScanRegistry>,
        output_schema: SchemaRef,
        max_derived_bytes: usize,
        derivation_tracker: Option<Arc<ProvenanceStore>>,
        iteration_counts: Arc<StdRwLock<HashMap<String, usize>>>,
        strict_probability_domain: bool,
        probability_epsilon: f64,
        exact_probability: bool,
        max_bdd_variables: usize,
        warnings_slot: Arc<StdRwLock<Vec<RuntimeWarning>>>,
        approximate_slot: Arc<StdRwLock<HashMap<String, Vec<String>>>>,
        top_k_proofs: usize,
        timeout_flag: Arc<std::sync::atomic::AtomicU8>,
        semiring_kind: SemiringKind,
        classifier_registry: Arc<ClassifierRegistry>,
        classifier_cache: Option<Arc<ModelInvocationCache>>,
        classifier_provenance_store: Option<Arc<uni_locy::NeuralProvenanceStore>>,
    ) -> Self {
        // Computed before `output_schema` is moved into the struct below.
        let properties = compute_plan_properties(Arc::clone(&output_schema));
        Self {
            rules,
            max_iterations,
            timeout,
            graph_ctx,
            session_ctx,
            storage,
            schema_info,
            params,
            derived_scan_registry,
            output_schema,
            properties,
            metrics: ExecutionPlanMetricsSet::new(),
            max_derived_bytes,
            derivation_tracker,
            iteration_counts,
            strict_probability_domain,
            probability_epsilon,
            exact_probability,
            max_bdd_variables,
            warnings_slot,
            approximate_slot,
            top_k_proofs,
            timeout_flag,
            semiring_kind,
            classifier_registry,
            classifier_cache,
            classifier_provenance_store,
            // Installed later through `set_profile_collector`, if at all.
            profile_collector: None,
        }
    }

    /// Returns another handle to the shared iteration-count map (the same map
    /// that `execute` passes to the fixpoint loop).
    pub fn iteration_counts(&self) -> Arc<StdRwLock<HashMap<String, usize>>> {
        Arc::clone(&self.iteration_counts)
    }

    /// Installs the profile collector that later `execute` calls forward to the
    /// fixpoint loop, replacing any collector set before.
    pub fn set_profile_collector(&mut self, collector: Arc<LocyProfileCollector>) {
        self.profile_collector = Some(collector);
    }
}
5482
5483impl DisplayAs for FixpointExec {
5484 fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter<'_>) -> fmt::Result {
5485 write!(
5486 f,
5487 "FixpointExec: rules=[{}], max_iter={}, timeout={:?}",
5488 self.rules
5489 .iter()
5490 .map(|r| r.name.as_str())
5491 .collect::<Vec<_>>()
5492 .join(", "),
5493 self.max_iterations,
5494 self.timeout,
5495 )
5496 }
5497}
5498
/// DataFusion `ExecutionPlan` integration for [`FixpointExec`].
impl ExecutionPlan for FixpointExec {
    /// Operator name reported to DataFusion.
    fn name(&self) -> &str {
        "FixpointExec"
    }

    /// Exposes `self` for downcasting.
    fn as_any(&self) -> &dyn Any {
        self
    }

    /// Output schema supplied at construction time.
    fn schema(&self) -> SchemaRef {
        Arc::clone(&self.output_schema)
    }

    /// Plan properties computed from the output schema at construction time.
    fn properties(&self) -> &Arc<PlanProperties> {
        &self.properties
    }

    /// This operator is a leaf: it exposes no child plans.
    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
        vec![]
    }

    /// Accepts only an empty child list and then returns `self` unchanged.
    ///
    /// # Errors
    /// Returns a `DataFusionError::Plan` if any child is supplied.
    fn with_new_children(
        self: Arc<Self>,
        children: Vec<Arc<dyn ExecutionPlan>>,
    ) -> DFResult<Arc<dyn ExecutionPlan>> {
        if !children.is_empty() {
            return Err(datafusion::error::DataFusionError::Plan(
                "FixpointExec has no children".to_string(),
            ));
        }
        Ok(self)
    }

    /// Starts a fixpoint evaluation and returns a stream over its result.
    ///
    /// Everything the evaluation needs is copied or `Arc`-cloned out of `self`
    /// and moved into a future that awaits `run_fixpoint_loop`. That future is
    /// stored in the returned `FixpointStream` and only makes progress when
    /// the stream is polled.
    ///
    /// `partition` is used only to label the baseline metrics and `_context`
    /// is unused.
    /// NOTE(review): no partition bound is checked here — confirm that the
    /// plan properties advertise a single partition.
    fn execute(
        &self,
        partition: usize,
        _context: Arc<TaskContext>,
    ) -> DFResult<SendableRecordBatchStream> {
        let metrics = BaselineMetrics::new(&self.metrics, partition);

        // Field-by-field copy of every rule plan (and of each clause plan
        // inside it) so the future below owns its own set of rules.
        let rules = self
            .rules
            .iter()
            .map(|r| {
                FixpointRulePlan {
                    name: r.name.clone(),
                    clauses: r
                        .clauses
                        .iter()
                        .map(|c| FixpointClausePlan {
                            body_logical: c.body_logical.clone(),
                            is_ref_bindings: c.is_ref_bindings.clone(),
                            priority: c.priority,
                            along_bindings: c.along_bindings.clone(),
                            model_invocations: c.model_invocations.clone(),
                        })
                        .collect(),
                    yield_schema: Arc::clone(&r.yield_schema),
                    key_column_indices: r.key_column_indices.clone(),
                    priority: r.priority,
                    has_fold: r.has_fold,
                    fold_bindings: r.fold_bindings.clone(),
                    having: r.having.clone(),
                    has_best_by: r.has_best_by,
                    best_by_criteria: r.best_by_criteria.clone(),
                    has_priority: r.has_priority,
                    deterministic: r.deterministic,
                    prob_column_name: r.prob_column_name.clone(),
                    non_linear: r.non_linear,
                }
            })
            .collect();

        // Owned copies / shared handles for the `async move` block, which
        // cannot borrow from `self`.
        let max_iterations = self.max_iterations;
        let timeout = self.timeout;
        let graph_ctx = Arc::clone(&self.graph_ctx);
        let session_ctx = Arc::clone(&self.session_ctx);
        let storage = Arc::clone(&self.storage);
        let schema_info = Arc::clone(&self.schema_info);
        let params = self.params.clone();
        let registry = Arc::clone(&self.derived_scan_registry);
        let output_schema = Arc::clone(&self.output_schema);
        let max_derived_bytes = self.max_derived_bytes;
        let derivation_tracker = self.derivation_tracker.clone();
        let iteration_counts = Arc::clone(&self.iteration_counts);
        let strict_probability_domain = self.strict_probability_domain;
        let probability_epsilon = self.probability_epsilon;
        let exact_probability = self.exact_probability;
        let max_bdd_variables = self.max_bdd_variables;
        let warnings_slot = Arc::clone(&self.warnings_slot);
        let approximate_slot = Arc::clone(&self.approximate_slot);
        let top_k_proofs = self.top_k_proofs;
        let timeout_flag = Arc::clone(&self.timeout_flag);
        let semiring_kind = self.semiring_kind;
        let classifier_registry = Arc::clone(&self.classifier_registry);
        let classifier_cache = self.classifier_cache.as_ref().map(Arc::clone);
        let classifier_provenance_store = self.classifier_provenance_store.as_ref().map(Arc::clone);
        let profile_collector = self.profile_collector.as_ref().map(Arc::clone);

        // The call is positional: argument order must match the signature of
        // `run_fixpoint_loop` exactly.
        let fut = async move {
            run_fixpoint_loop(
                rules,
                max_iterations,
                timeout,
                graph_ctx,
                session_ctx,
                storage,
                schema_info,
                params,
                registry,
                output_schema,
                max_derived_bytes,
                derivation_tracker,
                iteration_counts,
                strict_probability_domain,
                probability_epsilon,
                exact_probability,
                max_bdd_variables,
                warnings_slot,
                approximate_slot,
                top_k_proofs,
                timeout_flag,
                semiring_kind,
                classifier_registry,
                classifier_cache,
                classifier_provenance_store,
                profile_collector,
            )
            .await
        };

        // The stream starts in `Running` and switches to emitting batches once
        // the future resolves.
        Ok(Box::pin(FixpointStream {
            state: FixpointStreamState::Running(Box::pin(fut)),
            schema: Arc::clone(&self.output_schema),
            metrics,
        }))
    }

    /// Snapshot of the metrics recorded so far.
    fn metrics(&self) -> Option<MetricsSet> {
        Some(self.metrics.clone_inner())
    }
}
5646
/// Progress of a [`FixpointStream`].
enum FixpointStreamState {
    /// The fixpoint future is still being driven; it resolves to every result
    /// batch at once (or to an error).
    Running(Pin<Box<dyn std::future::Future<Output = DFResult<Vec<RecordBatch>>> + Send>>),
    /// The future has resolved; the `usize` is the index of the next batch to
    /// hand out from the vector.
    Emitting(Vec<RecordBatch>, usize),
    /// Terminal state: the stream yields `None` from here on.
    Done,
}
5659
/// Record-batch stream returned by `FixpointExec::execute`.
///
/// It first drives the fixpoint future to completion and then yields the
/// resulting batches one at a time.
struct FixpointStream {
    /// Current phase (running the future, emitting batches, or done).
    state: FixpointStreamState,
    /// Schema reported through `RecordBatchStream::schema`.
    schema: SchemaRef,
    /// Baseline metrics: compute time per poll and emitted row counts.
    metrics: BaselineMetrics,
}
5665
5666impl Stream for FixpointStream {
5667 type Item = DFResult<RecordBatch>;
5668
5669 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
5670 let this = self.get_mut();
5671 let metrics = this.metrics.clone();
5672 let _timer = metrics.elapsed_compute().timer();
5673 loop {
5674 match &mut this.state {
5675 FixpointStreamState::Running(fut) => match fut.as_mut().poll(cx) {
5676 Poll::Ready(Ok(batches)) => {
5677 if batches.is_empty() {
5678 this.state = FixpointStreamState::Done;
5679 return Poll::Ready(None);
5680 }
5681 this.state = FixpointStreamState::Emitting(batches, 0);
5682 }
5684 Poll::Ready(Err(e)) => {
5685 this.state = FixpointStreamState::Done;
5686 return Poll::Ready(Some(Err(e)));
5687 }
5688 Poll::Pending => return Poll::Pending,
5689 },
5690 FixpointStreamState::Emitting(batches, idx) => {
5691 if *idx >= batches.len() {
5692 this.state = FixpointStreamState::Done;
5693 return Poll::Ready(None);
5694 }
5695 let batch = batches[*idx].clone();
5696 *idx += 1;
5697 this.metrics.record_output(batch.num_rows());
5698 return Poll::Ready(Some(Ok(batch)));
5699 }
5700 FixpointStreamState::Done => return Poll::Ready(None),
5701 }
5702 }
5703 }
5704}
5705
/// Lets DataFusion treat [`FixpointStream`] as a `SendableRecordBatchStream`.
impl RecordBatchStream for FixpointStream {
    /// Schema of the batches this stream yields (another handle to the stored `SchemaRef`).
    fn schema(&self) -> SchemaRef {
        Arc::clone(&self.schema)
    }
}
5711
5712#[cfg(test)]
5717mod tests {
5718 use super::*;
5719 use arrow_array::{Float64Array, Int64Array, StringArray};
5720 use arrow_schema::{DataType, Field, Schema};
5721
5722 fn test_schema() -> SchemaRef {
5723 Arc::new(Schema::new(vec![
5724 Field::new("name", DataType::Utf8, true),
5725 Field::new("value", DataType::Int64, true),
5726 ]))
5727 }
5728
5729 fn make_batch(names: &[&str], values: &[i64]) -> RecordBatch {
5730 RecordBatch::try_new(
5731 test_schema(),
5732 vec![
5733 Arc::new(StringArray::from(
5734 names.iter().map(|s| Some(*s)).collect::<Vec<_>>(),
5735 )),
5736 Arc::new(Int64Array::from(values.to_vec())),
5737 ],
5738 )
5739 .unwrap()
5740 }
5741
5742 #[tokio::test]
5745 async fn test_fixpoint_state_empty_facts_adds_all() {
5746 let schema = test_schema();
5747 let mut state = FixpointState::new("test".into(), schema, vec![0], 1_000_000, None, false);
5748
5749 let batch = make_batch(&["a", "b", "c"], &[1, 2, 3]);
5750 let changed = state.merge_delta(vec![batch], None).await.unwrap();
5751
5752 assert!(changed);
5753 assert_eq!(state.all_facts().len(), 1);
5754 assert_eq!(state.all_facts()[0].num_rows(), 3);
5755 assert_eq!(state.all_delta().len(), 1);
5756 assert_eq!(state.all_delta()[0].num_rows(), 3);
5757 }
5758
5759 #[tokio::test]
5760 async fn test_fixpoint_state_exact_duplicates_excluded() {
5761 let schema = test_schema();
5762 let mut state = FixpointState::new("test".into(), schema, vec![0], 1_000_000, None, false);
5763
5764 let batch1 = make_batch(&["a", "b"], &[1, 2]);
5765 state.merge_delta(vec![batch1], None).await.unwrap();
5766
5767 let batch2 = make_batch(&["a", "b"], &[1, 2]);
5769 let changed = state.merge_delta(vec![batch2], None).await.unwrap();
5770 assert!(!changed);
5771 assert!(
5772 state.all_delta().is_empty() || state.all_delta().iter().all(|b| b.num_rows() == 0)
5773 );
5774 }
5775
5776 #[tokio::test]
5777 async fn test_fixpoint_state_partial_overlap() {
5778 let schema = test_schema();
5779 let mut state = FixpointState::new("test".into(), schema, vec![0], 1_000_000, None, false);
5780
5781 let batch1 = make_batch(&["a", "b"], &[1, 2]);
5782 state.merge_delta(vec![batch1], None).await.unwrap();
5783
5784 let batch2 = make_batch(&["a", "c"], &[1, 3]);
5786 let changed = state.merge_delta(vec![batch2], None).await.unwrap();
5787 assert!(changed);
5788
5789 let delta_rows: usize = state.all_delta().iter().map(|b| b.num_rows()).sum();
5791 assert_eq!(delta_rows, 1);
5792
5793 let total_rows: usize = state.all_facts().iter().map(|b| b.num_rows()).sum();
5795 assert_eq!(total_rows, 3);
5796 }
5797
5798 #[tokio::test]
5799 async fn test_fixpoint_state_convergence() {
5800 let schema = test_schema();
5801 let mut state = FixpointState::new("test".into(), schema, vec![0], 1_000_000, None, false);
5802
5803 let batch = make_batch(&["a"], &[1]);
5804 state.merge_delta(vec![batch], None).await.unwrap();
5805
5806 let changed = state.merge_delta(vec![], None).await.unwrap();
5808 assert!(!changed);
5809 assert!(state.is_converged());
5810 }
5811
5812 #[test]
5815 fn test_row_dedup_persistent_across_calls() {
5816 let schema = test_schema();
5819 let mut rd = RowDedupState::try_new(&schema).expect("schema should be supported");
5820
5821 let batch1 = make_batch(&["a", "b"], &[1, 2]);
5822 let delta1 = rd.compute_delta(&[batch1], &schema).unwrap();
5823 let rows1: usize = delta1.iter().map(|b| b.num_rows()).sum();
5825 assert_eq!(rows1, 2);
5826
5827 let batch2 = make_batch(&["a", "b"], &[1, 2]);
5829 let delta2 = rd.compute_delta(&[batch2], &schema).unwrap();
5830 let rows2: usize = delta2.iter().map(|b| b.num_rows()).sum();
5831 assert_eq!(rows2, 0);
5832
5833 let batch3 = make_batch(&["a", "c"], &[1, 3]);
5835 let delta3 = rd.compute_delta(&[batch3], &schema).unwrap();
5836 let rows3: usize = delta3.iter().map(|b| b.num_rows()).sum();
5837 assert_eq!(rows3, 1);
5838 }
5839
5840 #[test]
5841 fn test_row_dedup_null_handling() {
5842 use arrow_array::StringArray;
5843 use arrow_schema::{DataType, Field, Schema};
5844
5845 let schema: SchemaRef = Arc::new(Schema::new(vec![
5846 Field::new("a", DataType::Utf8, true),
5847 Field::new("b", DataType::Int64, true),
5848 ]));
5849 let mut rd = RowDedupState::try_new(&schema).expect("schema should be supported");
5850
5851 let batch_nulls = RecordBatch::try_new(
5853 Arc::clone(&schema),
5854 vec![
5855 Arc::new(StringArray::from(vec![None::<&str>, None::<&str>])),
5856 Arc::new(arrow_array::Int64Array::from(vec![1i64, 1i64])),
5857 ],
5858 )
5859 .unwrap();
5860 let delta = rd.compute_delta(&[batch_nulls], &schema).unwrap();
5861 let rows: usize = delta.iter().map(|b| b.num_rows()).sum();
5862 assert_eq!(rows, 1, "two identical NULL rows should be deduped to one");
5863
5864 let batch_diff = RecordBatch::try_new(
5866 Arc::clone(&schema),
5867 vec![
5868 Arc::new(StringArray::from(vec![None::<&str>])),
5869 Arc::new(arrow_array::Int64Array::from(vec![2i64])),
5870 ],
5871 )
5872 .unwrap();
5873 let delta2 = rd.compute_delta(&[batch_diff], &schema).unwrap();
5874 let rows2: usize = delta2.iter().map(|b| b.num_rows()).sum();
5875 assert_eq!(rows2, 1, "(NULL, 2) is distinct from (NULL, 1)");
5876 }
5877
5878 #[test]
5879 fn test_row_dedup_within_candidate_dedup() {
5880 let schema = test_schema();
5882 let mut rd = RowDedupState::try_new(&schema).expect("schema should be supported");
5883
5884 let batch = make_batch(&["a", "a", "b"], &[1, 1, 2]);
5886 let delta = rd.compute_delta(&[batch], &schema).unwrap();
5887 let rows: usize = delta.iter().map(|b| b.num_rows()).sum();
5888 assert_eq!(rows, 2, "within-batch dup should be collapsed: a:1, b:2");
5889 }
5890
5891 #[test]
5894 fn test_round_float_columns_near_duplicates() {
5895 let schema = Arc::new(Schema::new(vec![
5896 Field::new("name", DataType::Utf8, true),
5897 Field::new("dist", DataType::Float64, true),
5898 ]));
5899 let batch = RecordBatch::try_new(
5900 Arc::clone(&schema),
5901 vec![
5902 Arc::new(StringArray::from(vec![Some("a"), Some("a")])),
5903 Arc::new(Float64Array::from(vec![1.0000000000001, 1.0000000000002])),
5904 ],
5905 )
5906 .unwrap();
5907
5908 let rounded = round_float_columns(&[batch]);
5909 assert_eq!(rounded.len(), 1);
5910 let col = rounded[0]
5911 .column(1)
5912 .as_any()
5913 .downcast_ref::<Float64Array>()
5914 .unwrap();
5915 assert_eq!(col.value(0), col.value(1));
5917 }
5918
5919 #[test]
5922 fn test_registry_write_read_round_trip() {
5923 let schema = test_schema();
5924 let data = Arc::new(RwLock::new(Vec::new()));
5925 let mut reg = DerivedScanRegistry::new();
5926 reg.add(DerivedScanEntry {
5927 scan_index: 0,
5928 rule_name: "reachable".into(),
5929 is_self_ref: true,
5930 data: Arc::clone(&data),
5931 schema: Arc::clone(&schema),
5932 });
5933
5934 let batch = make_batch(&["x"], &[42]);
5935 reg.write_data(0, vec![batch.clone()]);
5936
5937 let entry = reg.get(0).unwrap();
5938 let guard = entry.data.read();
5939 assert_eq!(guard.len(), 1);
5940 assert_eq!(guard[0].num_rows(), 1);
5941 }
5942
5943 #[test]
5944 fn test_registry_entries_for_rule() {
5945 let schema = test_schema();
5946 let mut reg = DerivedScanRegistry::new();
5947 reg.add(DerivedScanEntry {
5948 scan_index: 0,
5949 rule_name: "r1".into(),
5950 is_self_ref: true,
5951 data: Arc::new(RwLock::new(Vec::new())),
5952 schema: Arc::clone(&schema),
5953 });
5954 reg.add(DerivedScanEntry {
5955 scan_index: 1,
5956 rule_name: "r2".into(),
5957 is_self_ref: false,
5958 data: Arc::new(RwLock::new(Vec::new())),
5959 schema: Arc::clone(&schema),
5960 });
5961 reg.add(DerivedScanEntry {
5962 scan_index: 2,
5963 rule_name: "r1".into(),
5964 is_self_ref: false,
5965 data: Arc::new(RwLock::new(Vec::new())),
5966 schema: Arc::clone(&schema),
5967 });
5968
5969 assert_eq!(reg.entries_for_rule("r1").len(), 2);
5970 assert_eq!(reg.entries_for_rule("r2").len(), 1);
5971 assert_eq!(reg.entries_for_rule("r3").len(), 0);
5972 }
5973
5974 #[test]
5977 fn test_monotonic_agg_update_and_stability() {
5978 let bindings = vec![MonotonicFoldBinding {
5979 fold_name: "total".into(),
5980 aggregate: std::sync::Arc::new(uni_plugin_builtin::locy_aggregates::SumAgg),
5981 input_col_index: 1,
5982 input_col_name: None,
5983 }];
5984 let mut agg = MonotonicAggState::new(bindings);
5985
5986 let batch = make_batch(&["a"], &[10]);
5988 agg.snapshot();
5989 let changed = agg
5990 .update(&[0], &[batch], false, SemiringKind::AddMultProb)
5991 .unwrap();
5992 assert!(changed);
5993 assert!(!agg.is_stable()); agg.snapshot();
5997 let changed = agg
5998 .update(&[0], &[], false, SemiringKind::AddMultProb)
5999 .unwrap();
6000 assert!(!changed);
6001 assert!(agg.is_stable());
6002 }
6003
6004 #[tokio::test]
6007 async fn test_memory_limit_exceeded() {
6008 let schema = test_schema();
6009 let mut state = FixpointState::new("test".into(), schema, vec![0], 1, None, false);
6011
6012 let batch = make_batch(&["a", "b", "c"], &[1, 2, 3]);
6013 let result = state.merge_delta(vec![batch], None).await;
6014 assert!(result.is_err());
6015 let err = result.unwrap_err().to_string();
6016 assert!(err.contains("memory limit"), "Error was: {}", err);
6017 }
6018
6019 #[tokio::test]
6022 async fn test_fixpoint_stream_emitting() {
6023 use futures::StreamExt;
6024
6025 let schema = test_schema();
6026 let batch1 = make_batch(&["a"], &[1]);
6027 let batch2 = make_batch(&["b"], &[2]);
6028
6029 let metrics = ExecutionPlanMetricsSet::new();
6030 let baseline = BaselineMetrics::new(&metrics, 0);
6031
6032 let mut stream = FixpointStream {
6033 state: FixpointStreamState::Emitting(vec![batch1, batch2], 0),
6034 schema,
6035 metrics: baseline,
6036 };
6037
6038 let stream = Pin::new(&mut stream);
6039 let batches: Vec<RecordBatch> = stream.filter_map(|r| async { r.ok() }).collect().await;
6040
6041 assert_eq!(batches.len(), 2);
6042 assert_eq!(batches[0].num_rows(), 1);
6043 assert_eq!(batches[1].num_rows(), 1);
6044 }
6045
6046 fn make_f64_batch(names: &[&str], values: &[f64]) -> RecordBatch {
6049 let schema = Arc::new(Schema::new(vec![
6050 Field::new("name", DataType::Utf8, true),
6051 Field::new("value", DataType::Float64, true),
6052 ]));
6053 RecordBatch::try_new(
6054 schema,
6055 vec![
6056 Arc::new(StringArray::from(
6057 names.iter().map(|s| Some(*s)).collect::<Vec<_>>(),
6058 )),
6059 Arc::new(Float64Array::from(values.to_vec())),
6060 ],
6061 )
6062 .unwrap()
6063 }
6064
6065 fn make_nor_binding() -> Vec<MonotonicFoldBinding> {
6066 vec![MonotonicFoldBinding {
6067 fold_name: "prob".into(),
6068 aggregate: std::sync::Arc::new(uni_plugin_builtin::locy_aggregates::MnorAgg),
6069 input_col_index: 1,
6070 input_col_name: None,
6071 }]
6072 }
6073
6074 fn make_prod_binding() -> Vec<MonotonicFoldBinding> {
6075 vec![MonotonicFoldBinding {
6076 fold_name: "prob".into(),
6077 aggregate: std::sync::Arc::new(uni_plugin_builtin::locy_aggregates::MprodAgg),
6078 input_col_index: 1,
6079 input_col_name: None,
6080 }]
6081 }
6082
6083 fn acc_key(name: &str) -> (Vec<ScalarKey>, String) {
6084 (vec![ScalarKey::Utf8(name.to_string())], "prob".to_string())
6085 }
6086
6087 #[test]
6088 fn test_monotonic_nor_first_update() {
6089 let mut agg = MonotonicAggState::new(make_nor_binding());
6090 let batch = make_f64_batch(&["a"], &[0.3]);
6091 let changed = agg
6092 .update(&[0], &[batch], false, SemiringKind::AddMultProb)
6093 .unwrap();
6094 assert!(changed);
6095 let val = agg.get_accumulator(&acc_key("a")).unwrap();
6096 assert!((val - 0.3).abs() < 1e-10, "expected 0.3, got {}", val);
6097 }
6098
6099 #[test]
6100 fn test_monotonic_nor_two_updates() {
6101 let mut agg = MonotonicAggState::new(make_nor_binding());
6103 let batch1 = make_f64_batch(&["a"], &[0.3]);
6104 agg.update(&[0], &[batch1], false, SemiringKind::AddMultProb)
6105 .unwrap();
6106 let batch2 = make_f64_batch(&["a"], &[0.5]);
6107 agg.update(&[0], &[batch2], false, SemiringKind::AddMultProb)
6108 .unwrap();
6109 let val = agg.get_accumulator(&acc_key("a")).unwrap();
6110 assert!((val - 0.65).abs() < 1e-10, "expected 0.65, got {}", val);
6111 }
6112
6113 #[test]
6114 fn test_monotonic_prod_first_update() {
6115 let mut agg = MonotonicAggState::new(make_prod_binding());
6116 let batch = make_f64_batch(&["a"], &[0.6]);
6117 let changed = agg
6118 .update(&[0], &[batch], false, SemiringKind::AddMultProb)
6119 .unwrap();
6120 assert!(changed);
6121 let val = agg.get_accumulator(&acc_key("a")).unwrap();
6122 assert!((val - 0.6).abs() < 1e-10, "expected 0.6, got {}", val);
6123 }
6124
6125 #[test]
6126 fn test_monotonic_prod_two_updates() {
6127 let mut agg = MonotonicAggState::new(make_prod_binding());
6129 let batch1 = make_f64_batch(&["a"], &[0.6]);
6130 agg.update(&[0], &[batch1], false, SemiringKind::AddMultProb)
6131 .unwrap();
6132 let batch2 = make_f64_batch(&["a"], &[0.8]);
6133 agg.update(&[0], &[batch2], false, SemiringKind::AddMultProb)
6134 .unwrap();
6135 let val = agg.get_accumulator(&acc_key("a")).unwrap();
6136 assert!((val - 0.48).abs() < 1e-10, "expected 0.48, got {}", val);
6137 }
6138
6139 #[test]
6140 fn test_monotonic_nor_stability() {
6141 let mut agg = MonotonicAggState::new(make_nor_binding());
6142 let batch = make_f64_batch(&["a"], &[0.3]);
6143 agg.update(&[0], &[batch], false, SemiringKind::AddMultProb)
6144 .unwrap();
6145 agg.snapshot();
6146 let changed = agg
6147 .update(&[0], &[], false, SemiringKind::AddMultProb)
6148 .unwrap();
6149 assert!(!changed);
6150 assert!(agg.is_stable());
6151 }
6152
6153 #[test]
6154 fn test_monotonic_prod_stability() {
6155 let mut agg = MonotonicAggState::new(make_prod_binding());
6156 let batch = make_f64_batch(&["a"], &[0.6]);
6157 agg.update(&[0], &[batch], false, SemiringKind::AddMultProb)
6158 .unwrap();
6159 agg.snapshot();
6160 let changed = agg
6161 .update(&[0], &[], false, SemiringKind::AddMultProb)
6162 .unwrap();
6163 assert!(!changed);
6164 assert!(agg.is_stable());
6165 }
6166
6167 #[test]
6168 fn test_monotonic_nor_multi_group() {
6169 let mut agg = MonotonicAggState::new(make_nor_binding());
6171 let batch1 = make_f64_batch(&["a", "b"], &[0.3, 0.5]);
6172 agg.update(&[0], &[batch1], false, SemiringKind::AddMultProb)
6173 .unwrap();
6174 let batch2 = make_f64_batch(&["a", "b"], &[0.5, 0.2]);
6175 agg.update(&[0], &[batch2], false, SemiringKind::AddMultProb)
6176 .unwrap();
6177
6178 let val_a = agg.get_accumulator(&acc_key("a")).unwrap();
6179 let val_b = agg.get_accumulator(&acc_key("b")).unwrap();
6180 assert!(
6181 (val_a - 0.65).abs() < 1e-10,
6182 "expected a=0.65, got {}",
6183 val_a
6184 );
6185 assert!((val_b - 0.6).abs() < 1e-10, "expected b=0.6, got {}", val_b);
6186 }
6187
6188 #[test]
6189 fn test_monotonic_prod_zero_absorbing() {
6190 let mut agg = MonotonicAggState::new(make_prod_binding());
6192 let batch1 = make_f64_batch(&["a"], &[0.5]);
6193 agg.update(&[0], &[batch1], false, SemiringKind::AddMultProb)
6194 .unwrap();
6195 let batch2 = make_f64_batch(&["a"], &[0.0]);
6196 agg.update(&[0], &[batch2], false, SemiringKind::AddMultProb)
6197 .unwrap();
6198
6199 let val = agg.get_accumulator(&acc_key("a")).unwrap();
6200 assert!((val - 0.0).abs() < 1e-10, "expected 0.0, got {}", val);
6201
6202 agg.snapshot();
6204 let batch3 = make_f64_batch(&["a"], &[0.5]);
6205 let changed = agg
6206 .update(&[0], &[batch3], false, SemiringKind::AddMultProb)
6207 .unwrap();
6208 assert!(!changed);
6209 assert!(agg.is_stable());
6210 }
6211
6212 #[test]
6213 fn test_monotonic_nor_clamping() {
6214 let mut agg = MonotonicAggState::new(make_nor_binding());
6216 let batch = make_f64_batch(&["a"], &[1.5]);
6217 agg.update(&[0], &[batch], false, SemiringKind::AddMultProb)
6218 .unwrap();
6219 let val = agg.get_accumulator(&acc_key("a")).unwrap();
6220 assert!((val - 1.0).abs() < 1e-10, "expected 1.0, got {}", val);
6221 }
6222
6223 #[test]
6224 fn test_monotonic_nor_absorbing() {
6225 let mut agg = MonotonicAggState::new(make_nor_binding());
6227 let batch1 = make_f64_batch(&["a"], &[0.3]);
6228 agg.update(&[0], &[batch1], false, SemiringKind::AddMultProb)
6229 .unwrap();
6230 let batch2 = make_f64_batch(&["a"], &[1.0]);
6231 agg.update(&[0], &[batch2], false, SemiringKind::AddMultProb)
6232 .unwrap();
6233 let val = agg.get_accumulator(&acc_key("a")).unwrap();
6234 assert!((val - 1.0).abs() < 1e-10, "expected 1.0, got {}", val);
6235 }
6236
6237 #[test]
6240 fn test_monotonic_agg_strict_nor_rejects() {
6241 let mut agg = MonotonicAggState::new(make_nor_binding());
6242 let batch = make_f64_batch(&["a"], &[1.5]);
6243 let result = agg.update(&[0], &[batch], true, SemiringKind::AddMultProb);
6244 assert!(result.is_err());
6245 let err = result.unwrap_err().to_string();
6246 assert!(
6247 err.contains("strict_probability_domain"),
6248 "Expected strict error, got: {}",
6249 err
6250 );
6251 }
6252
6253 #[test]
6254 fn test_monotonic_agg_strict_prod_rejects() {
6255 let mut agg = MonotonicAggState::new(make_prod_binding());
6256 let batch = make_f64_batch(&["a"], &[2.0]);
6257 let result = agg.update(&[0], &[batch], true, SemiringKind::AddMultProb);
6258 assert!(result.is_err());
6259 let err = result.unwrap_err().to_string();
6260 assert!(
6261 err.contains("strict_probability_domain"),
6262 "Expected strict error, got: {}",
6263 err
6264 );
6265 }
6266
/// With the third `update` argument set to `true`, an in-range input of 0.5
/// must be accepted and stored as the accumulator value for key "a".
#[test]
fn test_monotonic_agg_strict_accepts_valid() {
    let mut state = MonotonicAggState::new(make_nor_binding());
    let in_range = make_f64_batch(&["a"], &[0.5]);
    let outcome = state.update(&[0], &[in_range], true, SemiringKind::AddMultProb);
    assert!(outcome.is_ok());
    // A single 0.5 input is expected to come back unchanged.
    let got = state.get_accumulator(&acc_key("a")).unwrap();
    let delta = (got - 0.5).abs();
    assert!(delta < 1e-10, "expected 0.5, got {}", got);
}
6276
6277 fn make_vid_prob_batch(vids: &[u64], probs: &[f64]) -> RecordBatch {
6280 use arrow_array::UInt64Array;
6281 let schema = Arc::new(Schema::new(vec![
6282 Field::new("vid", DataType::UInt64, true),
6283 Field::new("prob", DataType::Float64, true),
6284 ]));
6285 RecordBatch::try_new(
6286 schema,
6287 vec![
6288 Arc::new(UInt64Array::from(vids.to_vec())),
6289 Arc::new(Float64Array::from(probs.to_vec())),
6290 ],
6291 )
6292 .unwrap()
6293 }
6294
/// Joins a two-row body against a single negated row on "vid" and checks the
/// emitted `__complement_0` column: 0.3 for the matched row and 1.0 for the
/// unmatched row.
#[test]
fn test_prob_complement_basic() {
    let join_cols = vec![("vid".to_string(), "vid".to_string())];
    let body = make_vid_prob_batch(&[1, 2], &[0.9, 0.8]);
    // Only vid 1 appears on the negated side, with probability 0.7.
    let neg = make_vid_prob_batch(&[1], &[0.7]);

    let batches = apply_prob_complement_composite(
        vec![body],
        &[neg],
        &join_cols,
        "prob",
        "__complement_0",
    )
    .unwrap();
    assert_eq!(batches.len(), 1);

    let column = batches[0].column_by_name("__complement_0").unwrap();
    let complement = column.as_any().downcast_ref::<Float64Array>().unwrap();

    // Row 0 (vid 1) matched the 0.7 row: the expected complement is 0.3.
    let matched = complement.value(0);
    assert!((matched - 0.3).abs() < 1e-10, "expected 0.3, got {}", matched);

    // Row 1 (vid 2) has no negated counterpart: the expected complement is 1.0.
    let unmatched = complement.value(1);
    assert!(
        (unmatched - 1.0).abs() < 1e-10,
        "expected 1.0, got {}",
        unmatched
    );
}
6330
/// When the negated side holds two rows for the same vid (0.3 and 0.5), the
/// complement for that vid is expected to be 0.35, i.e. (1 - 0.3) * (1 - 0.5).
#[test]
fn test_prob_complement_noisy_or_duplicates() {
    let join_cols = vec![("vid".to_string(), "vid".to_string())];
    let body = make_vid_prob_batch(&[1], &[0.9]);
    // Two negated rows share vid 1.
    let neg = make_vid_prob_batch(&[1, 1], &[0.3, 0.5]);

    let batches = apply_prob_complement_composite(
        vec![body],
        &[neg],
        &join_cols,
        "prob",
        "__complement_0",
    )
    .unwrap();

    let column = batches[0].column_by_name("__complement_0").unwrap();
    let complement = column.as_any().downcast_ref::<Float64Array>().unwrap();
    let got = complement.value(0);
    assert!((got - 0.35).abs() < 1e-10, "expected 0.35, got {}", got);
}
6360
/// With no negated batches at all, every body row is expected to receive a
/// complement of 1.0.
#[test]
fn test_prob_complement_empty_neg() {
    let join_cols = vec![("vid".to_string(), "vid".to_string())];
    let body = make_vid_prob_batch(&[1, 2], &[0.5, 0.6]);

    // The negated side is an empty slice.
    let batches =
        apply_prob_complement_composite(vec![body], &[], &join_cols, "prob", "__complement_0")
            .unwrap();

    let column = batches[0].column_by_name("__complement_0").unwrap();
    let complement = column.as_any().downcast_ref::<Float64Array>().unwrap();
    for row in 0..2 {
        let got = complement.value(row);
        assert!(
            (got - 1.0).abs() < 1e-10,
            "row {}: expected 1.0, got {}",
            row,
            got
        );
    }
}
6385
/// Anti-joining body vids {1, 2, 3} against negated vid {2} must keep exactly
/// the rows for vids 1 and 3, in that order, in a single output batch.
#[test]
fn test_anti_join_basic() {
    use arrow_array::UInt64Array;

    let join_cols = vec![("vid".to_string(), "vid".to_string())];
    let body = make_vid_prob_batch(&[1, 2, 3], &[0.5, 0.6, 0.7]);
    let neg = make_vid_prob_batch(&[2], &[0.0]);

    let kept = apply_anti_join_composite(vec![body], &[neg], &join_cols).unwrap();
    assert_eq!(kept.len(), 1);

    let survivors = &kept[0];
    assert_eq!(survivors.num_rows(), 2);

    let column = survivors.column_by_name("vid").unwrap();
    let vids = column.as_any().downcast_ref::<UInt64Array>().unwrap();
    // vid 2 was on the negated side, so only 1 and 3 remain.
    assert_eq!([vids.value(0), vids.value(1)], [1, 3]);
}
6406
/// With an empty negated side, the anti-join must return one batch that still
/// holds all three body rows.
#[test]
fn test_anti_join_empty_neg() {
    let join_cols = vec![("vid".to_string(), "vid".to_string())];
    let body = make_vid_prob_batch(&[1, 2, 3], &[0.5, 0.6, 0.7]);

    // Nothing to exclude: the negated side is an empty slice.
    let kept = apply_anti_join_composite(vec![body], &[], &join_cols).unwrap();

    assert_eq!(kept.len(), 1);
    assert_eq!(kept[0].num_rows(), 3);
}
6416
/// When every body vid also appears on the negated side, the anti-join output
/// must contain zero rows in total, however many batches it returns.
#[test]
fn test_anti_join_all_excluded() {
    let join_cols = vec![("vid".to_string(), "vid".to_string())];
    let body = make_vid_prob_batch(&[1, 2], &[0.5, 0.6]);
    let neg = make_vid_prob_batch(&[1, 2], &[0.0, 0.0]);

    let kept = apply_anti_join_composite(vec![body], &[neg], &join_cols).unwrap();

    // Count rows across all returned batches; the batch count is not pinned.
    let mut total: usize = 0;
    for batch in kept.iter() {
        total += batch.num_rows();
    }
    assert_eq!(total, 0);
}
6427
/// A single complement column (0.5) must be multiplied into "prob" (0.8),
/// giving 0.4, and the complement column must be absent from the output.
#[test]
fn test_multiply_prob_single_complement() {
    let body = make_vid_prob_batch(&[1], &[0.8]);

    // Extend the body batch with one extra Float64 column named
    // "__complement_0" holding 0.5.
    let mut fields: Vec<Arc<Field>> = body.schema().fields().iter().cloned().collect();
    fields.push(Arc::new(Field::new(
        "__complement_0",
        DataType::Float64,
        true,
    )));
    let mut cols: Vec<arrow_array::ArrayRef> = body.columns().to_vec();
    cols.push(Arc::new(Float64Array::from(vec![0.5])));
    let batch = RecordBatch::try_new(Arc::new(Schema::new(fields)), cols).unwrap();

    let complement_cols = ["__complement_0".to_string()];
    let batches = multiply_prob_factors(vec![batch], Some("prob"), &complement_cols).unwrap();
    assert_eq!(batches.len(), 1);

    let out = &batches[0];
    // The helper column must not be present in the output.
    assert!(out.column_by_name("__complement_0").is_none());

    let column = out.column_by_name("prob").unwrap();
    let prob = column.as_any().downcast_ref::<Float64Array>().unwrap();
    // 0.8 * 0.5 = 0.4
    let got = prob.value(0);
    assert!((got - 0.4).abs() < 1e-10, "expected 0.4, got {}", got);
}
6464
/// Two complement columns (0.5 and 0.6) must both be multiplied into "prob"
/// (0.8), giving 0.24, and both must be absent from the output.
#[test]
fn test_multiply_prob_multiple_complements() {
    let body = make_vid_prob_batch(&[1], &[0.8]);

    // Extend the body batch with two extra Float64 columns: "__c1" = 0.5 and
    // "__c2" = 0.6.
    let mut fields: Vec<Arc<Field>> = body.schema().fields().iter().cloned().collect();
    fields.push(Arc::new(Field::new("__c1", DataType::Float64, true)));
    fields.push(Arc::new(Field::new("__c2", DataType::Float64, true)));
    let mut cols: Vec<arrow_array::ArrayRef> = body.columns().to_vec();
    cols.push(Arc::new(Float64Array::from(vec![0.5])));
    cols.push(Arc::new(Float64Array::from(vec![0.6])));
    let batch = RecordBatch::try_new(Arc::new(Schema::new(fields)), cols).unwrap();

    let complement_cols = ["__c1".to_string(), "__c2".to_string()];
    let batches = multiply_prob_factors(vec![batch], Some("prob"), &complement_cols).unwrap();

    let out = &batches[0];
    // Neither helper column may be present in the output.
    assert!(out.column_by_name("__c1").is_none());
    assert!(out.column_by_name("__c2").is_none());

    let column = out.column_by_name("prob").unwrap();
    let prob = column.as_any().downcast_ref::<Float64Array>().unwrap();
    // 0.8 * 0.5 * 0.6 = 0.24
    let got = prob.value(0);
    assert!((got - 0.24).abs() < 1e-10, "expected 0.24, got {}", got);
}
6501
/// When no probability column is named (`None`), the complement column must
/// still be removed, leaving a single column in the output batch.
#[test]
fn test_multiply_prob_no_prob_column() {
    use arrow_array::UInt64Array;

    // Input batch: a "vid" column plus one complement column "__c1"; there is
    // no "prob" column at all.
    let vid_field = Field::new("vid", DataType::UInt64, true);
    let complement_field = Field::new("__c1", DataType::Float64, true);
    let schema = Arc::new(Schema::new(vec![vid_field, complement_field]));
    let batch = RecordBatch::try_new(
        schema,
        vec![
            Arc::new(UInt64Array::from(vec![1u64])),
            Arc::new(Float64Array::from(vec![0.7])),
        ],
    )
    .unwrap();

    let complement_cols = ["__c1".to_string()];
    let batches = multiply_prob_factors(vec![batch], None, &complement_cols).unwrap();

    let out = &batches[0];
    assert!(out.column_by_name("__c1").is_none());
    // Only one column is left once "__c1" is gone.
    assert_eq!(out.num_columns(), 1);
}
6526}