1use crate::query::df_graph::GraphExecutionContext;
10use crate::query::df_graph::common::{
11 ScalarKey, arrow_err, collect_all_partitions, compute_plan_properties, execute_subplan,
12 execute_subplan_collecting, extract_scalar_key,
13};
14use crate::query::df_graph::locy_best_by::{BestByExec, SortCriterion};
15use crate::query::df_graph::locy_errors::LocyRuntimeError;
16use crate::query::df_graph::locy_explain::{
17 ProofTerm, ProvenanceAnnotation, ProvenanceStore, compute_proof_probability,
18};
19use crate::query::df_graph::locy_fold::{FoldBinding, FoldExec};
20use crate::query::df_graph::locy_priority::PriorityExec;
21use crate::query::df_graph::locy_profile::LocyProfileCollector;
22use crate::query::df_graph::locy_program::interruption;
23use crate::query::executor::core::OperatorStats;
24use crate::query::planner::LogicalPlan;
25use arrow_array::RecordBatch;
26use arrow_row::{RowConverter, SortField};
27use arrow_schema::SchemaRef;
28use datafusion::common::JoinType;
29use datafusion::common::Result as DFResult;
30use datafusion::execution::{RecordBatchStream, SendableRecordBatchStream, TaskContext};
31use datafusion::physical_plan::joins::{HashJoinExec, PartitionMode};
32use datafusion::physical_plan::memory::MemoryStream;
33use datafusion::physical_plan::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet};
34use datafusion::physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties};
35use futures::Stream;
36use parking_lot::RwLock;
37use std::any::Any;
38use std::collections::{HashMap, HashSet};
39use std::fmt;
40use std::pin::Pin;
41use std::sync::{Arc, RwLock as StdRwLock};
42use std::task::{Context, Poll};
43use std::time::{Duration, Instant};
44use uni_common::Value;
45use uni_common::core::schema::Schema as UniSchema;
46use uni_cypher::ast::Expr;
47use uni_locy::{
48 ClassifierRegistry, ModelInvocation, ModelInvocationCache, RuntimeWarning, RuntimeWarningCode,
49 SemiringKind,
50};
51use uni_store::storage::manager::StorageManager;
52
/// One derived-relation scan slot registered for the fixpoint loop.
///
/// Batches are published into `data` (see `DerivedScanRegistry::write_data`,
/// which replaces the whole vector under a write lock) and readers clone them
/// out under a read lock.
#[derive(Debug)]
pub struct DerivedScanEntry {
    /// Identifier used by `DerivedScanRegistry::get` to find this entry.
    pub scan_index: usize,
    /// Rule this scan belongs to; matched by `DerivedScanRegistry::entries_for_rule`.
    pub rule_name: String,
    /// Whether the scan refers to the rule being evaluated (presumably a
    /// recursive self-reference — TODO confirm against the planner).
    pub is_self_ref: bool,
    /// Batches currently visible through this scan.
    pub data: Arc<RwLock<Vec<RecordBatch>>>,
    /// Schema of the batches in `data` (assumed; not enforced here).
    pub schema: SchemaRef,
}
75
/// Collection of [`DerivedScanEntry`] slots addressed by `scan_index`.
///
/// Entries are kept in a `Vec` and looked up by linear scan.
#[derive(Debug, Default)]
pub struct DerivedScanRegistry {
    /// Registered entries, in insertion order.
    entries: Vec<DerivedScanEntry>,
}
86
87impl DerivedScanRegistry {
88 pub fn new() -> Self {
90 Self::default()
91 }
92
93 pub fn add(&mut self, entry: DerivedScanEntry) {
95 self.entries.push(entry);
96 }
97
98 pub fn get(&self, scan_index: usize) -> Option<&DerivedScanEntry> {
100 self.entries.iter().find(|e| e.scan_index == scan_index)
101 }
102
103 pub fn write_data(&self, scan_index: usize, batches: Vec<RecordBatch>) {
105 if let Some(entry) = self.get(scan_index) {
106 let mut guard = entry.data.write();
107 *guard = batches;
108 }
109 }
110
111 pub fn entries_for_rule(&self, rule_name: &str) -> Vec<&DerivedScanEntry> {
113 self.entries
114 .iter()
115 .filter(|e| e.rule_name == rule_name)
116 .collect()
117 }
118}
119
/// A FOLD binding whose accumulator is maintained incrementally across
/// fixpoint iterations by [`MonotonicAggState`].
#[derive(Debug, Clone)]
pub struct MonotonicFoldBinding {
    /// Output name of the fold; forms the second half of the accumulator key.
    pub fold_name: String,
    /// Aggregate implementation used to initialise and step the accumulator.
    pub aggregate: std::sync::Arc<dyn uni_plugin::traits::locy::LocyAggregate>,
    /// Input column position, used when `input_col_name` is `None` or not
    /// present in the batch schema.
    pub input_col_index: usize,
    /// Preferred input column name, resolved against each batch's schema.
    pub input_col_name: Option<String>,
}
138
/// Running FOLD accumulators for one rule, plus the snapshot taken before the
/// latest update so convergence can be detected.
#[derive(Debug)]
pub struct MonotonicAggState {
    /// Current value per (group key, fold name).
    accumulators: HashMap<(Vec<ScalarKey>, String), f64>,
    /// Copy of `accumulators` taken by the most recent `snapshot()` call.
    prev_snapshot: HashMap<(Vec<ScalarKey>, String), f64>,
    /// Fold bindings fed by `update`.
    bindings: Vec<MonotonicFoldBinding>,
}
153
154impl MonotonicAggState {
155 pub fn new(bindings: Vec<MonotonicFoldBinding>) -> Self {
157 Self {
158 accumulators: HashMap::new(),
159 prev_snapshot: HashMap::new(),
160 bindings,
161 }
162 }
163
164 pub fn update(
190 &mut self,
191 key_indices: &[usize],
192 delta_batches: &[RecordBatch],
193 strict: bool,
194 semiring_kind: SemiringKind,
195 ) -> DFResult<bool> {
196 let mut changed = false;
197 for batch in delta_batches {
198 for row_idx in 0..batch.num_rows() {
199 let group_key = extract_scalar_key(batch, key_indices, row_idx);
200 for binding in &self.bindings {
201 let idx = binding
202 .input_col_name
203 .as_ref()
204 .and_then(|name| batch.schema().index_of(name).ok())
205 .unwrap_or(binding.input_col_index);
206 if idx >= batch.num_columns() {
207 continue;
208 }
209 let col = batch.column(idx);
210 let val = extract_f64(col.as_ref(), row_idx);
211 if let Some(val) = val {
212 let map_key = (group_key.clone(), binding.fold_name.clone());
213 let initial = binding.aggregate.initial_accum_f64().unwrap_or(0.0);
214 let entry = self.accumulators.entry(map_key).or_insert(initial);
215 let old = *entry;
216 if matches!(semiring_kind, SemiringKind::MaxMinProb)
227 && binding.aggregate.is_probability_aggregate()
228 {
229 use uni_locy::LocySemiring;
230 let sr = uni_locy::MaxMinProb;
231 let is_nor = binding.aggregate.is_noisy_or();
232 let label = if is_nor { "MNOR" } else { "MPROD" };
233 if strict && !(0.0..=1.0).contains(&val) {
234 return Err(datafusion::error::DataFusionError::Execution(
235 format!(
236 "strict_probability_domain: {label} input {val} is outside [0, 1]"
237 ),
238 ));
239 }
240 if !strict && !(0.0..=1.0).contains(&val) {
241 tracing::warn!(
242 "{label} input {val} outside [0,1], clamped to {}",
243 val.clamp(0.0, 1.0)
244 );
245 }
246 let p = val.clamp(0.0, 1.0);
247 *entry = if is_nor {
249 sr.plus(entry, &p)
250 } else {
251 sr.times(entry, &p)
252 };
253 if (*entry - old).abs() > f64::EPSILON {
254 changed = true;
255 }
256 continue;
257 }
258 match binding.aggregate.update_step(*entry, val, strict) {
259 Ok(new_val) => {
260 *entry = new_val;
261 if (*entry - old).abs() > f64::EPSILON {
262 changed = true;
263 }
264 }
265 Err(e) if e.code == uni_plugin::FnError::CODE_UNKNOWN_FUNCTION => {
266 }
270 Err(e) => {
271 return Err(datafusion::error::DataFusionError::Execution(
274 e.message,
275 ));
276 }
277 }
278 }
279 }
280 }
281 }
282 Ok(changed)
283 }
284
285 pub fn snapshot(&mut self) {
287 self.prev_snapshot = self.accumulators.clone();
288 }
289
290 pub fn is_stable(&self) -> bool {
292 if self.accumulators.len() != self.prev_snapshot.len() {
293 return false;
294 }
295 for (key, val) in &self.accumulators {
296 match self.prev_snapshot.get(key) {
297 Some(prev) if (*val - *prev).abs() <= f64::EPSILON => {}
298 _ => return false,
299 }
300 }
301 true
302 }
303
304 #[cfg(test)]
306 pub(crate) fn get_accumulator(&self, key: &(Vec<ScalarKey>, String)) -> Option<f64> {
307 self.accumulators.get(key).copied()
308 }
309}
310
311fn extract_f64(col: &dyn arrow_array::Array, row_idx: usize) -> Option<f64> {
313 if col.is_null(row_idx) {
314 return None;
315 }
316 if let Some(arr) = col.as_any().downcast_ref::<arrow_array::Float64Array>() {
317 Some(arr.value(row_idx))
318 } else {
319 col.as_any()
320 .downcast_ref::<arrow_array::Int64Array>()
321 .map(|arr| arr.value(row_idx) as f64)
322 }
323}
324
/// Whole-row dedup state based on Arrow row encoding.
///
/// Each row is encoded with `converter`, and its bytes are remembered in
/// `seen`, so equality is byte equality of the row encoding.
struct RowDedupState {
    /// Encoder built from the fact schema's column data types.
    converter: RowConverter,
    /// Encoded bytes of every row accepted so far.
    seen: HashSet<Box<[u8]>>,
}
338
339impl RowDedupState {
340 fn try_new(schema: &SchemaRef) -> Option<Self> {
345 let fields: Vec<SortField> = schema
346 .fields()
347 .iter()
348 .map(|f| SortField::new(f.data_type().clone()))
349 .collect();
350 match RowConverter::new(fields) {
351 Ok(converter) => Some(Self {
352 converter,
353 seen: HashSet::new(),
354 }),
355 Err(e) => {
356 tracing::warn!(
357 "RowDedupState: RowConverter unsupported for schema, falling back to legacy dedup: {}",
358 e
359 );
360 None
361 }
362 }
363 }
364
365 fn ingest_existing(&mut self, facts: &[RecordBatch], _schema: &SchemaRef) {
370 self.seen.clear();
371 for batch in facts {
372 if batch.num_rows() == 0 {
373 continue;
374 }
375 let arrays: Vec<_> = batch.columns().to_vec();
376 if let Ok(rows) = self.converter.convert_columns(&arrays) {
377 for row_idx in 0..batch.num_rows() {
378 let row_bytes: Box<[u8]> = rows.row(row_idx).data().into();
379 self.seen.insert(row_bytes);
380 }
381 }
382 }
383 }
384
385 fn compute_delta(
391 &mut self,
392 candidates: &[RecordBatch],
393 schema: &SchemaRef,
394 ) -> DFResult<Vec<RecordBatch>> {
395 let mut delta_batches = Vec::new();
396 for batch in candidates {
397 if batch.num_rows() == 0 {
398 continue;
399 }
400
401 let arrays: Vec<_> = batch.columns().to_vec();
403 let rows = self.converter.convert_columns(&arrays).map_err(arrow_err)?;
404
405 let mut keep = Vec::with_capacity(batch.num_rows());
407 for row_idx in 0..batch.num_rows() {
408 let row_bytes: Box<[u8]> = rows.row(row_idx).data().into();
409 keep.push(self.seen.insert(row_bytes));
410 }
411
412 let keep_mask = arrow_array::BooleanArray::from(keep);
413 let new_cols = batch
414 .columns()
415 .iter()
416 .map(|col| {
417 arrow::compute::filter(col.as_ref(), &keep_mask).map_err(|e| {
418 datafusion::error::DataFusionError::ArrowError(Box::new(e), None)
419 })
420 })
421 .collect::<DFResult<Vec<_>>>()?;
422
423 if new_cols.first().is_some_and(|c| !c.is_empty()) {
424 let filtered = RecordBatch::try_new(Arc::clone(schema), new_cols).map_err(|e| {
425 datafusion::error::DataFusionError::ArrowError(Box::new(e), None)
426 })?;
427 delta_batches.push(filtered);
428 }
429 }
430 Ok(delta_batches)
431 }
432}
433
/// Per-rule state of the fixpoint loop: accumulated facts, the latest delta,
/// dedup machinery, and optional monotonic-aggregate state.
pub struct FixpointState {
    /// Rule name, used in log and error messages.
    rule_name: String,
    /// All facts accumulated so far.
    facts: Vec<RecordBatch>,
    /// Rows that were new in the most recent merge.
    delta: Vec<RecordBatch>,
    /// Schema of the facts; may be replaced by `reconcile_schema`.
    schema: SchemaRef,
    /// Column positions forming the group key.
    key_column_indices: Vec<usize>,
    /// Names of the key columns, used to remap the indices after a schema change.
    key_column_names: Vec<String>,
    /// `0..num_columns`, used for whole-row keys by legacy dedup.
    all_column_indices: Vec<usize>,
    /// Sum of Arrow buffer sizes of the accumulated facts.
    facts_bytes: usize,
    /// Budget for `facts_bytes`, checked in `merge_delta`.
    max_derived_bytes: usize,
    /// FOLD accumulator state, present when the rule has fold bindings.
    monotonic_agg: Option<MonotonicAggState>,
    /// Row-encoding dedup state; `None` when `RowConverter` rejects the schema.
    row_dedup: Option<RowDedupState>,
    /// Forwarded to `MonotonicAggState::update` as `strict`.
    strict_probability_domain: bool,
    /// Forwarded to `MonotonicAggState::update`.
    semiring_kind: SemiringKind,
}
466
467impl FixpointState {
468 pub fn new(
473 rule_name: String,
474 schema: SchemaRef,
475 key_column_indices: Vec<usize>,
476 max_derived_bytes: usize,
477 monotonic_agg: Option<MonotonicAggState>,
478 strict_probability_domain: bool,
479 ) -> Self {
480 Self::new_with_semiring(
481 rule_name,
482 schema,
483 key_column_indices,
484 max_derived_bytes,
485 monotonic_agg,
486 strict_probability_domain,
487 SemiringKind::AddMultProb,
488 )
489 }
490
491 pub fn new_with_semiring(
492 rule_name: String,
493 schema: SchemaRef,
494 key_column_indices: Vec<usize>,
495 max_derived_bytes: usize,
496 monotonic_agg: Option<MonotonicAggState>,
497 strict_probability_domain: bool,
498 semiring_kind: SemiringKind,
499 ) -> Self {
500 let num_cols = schema.fields().len();
501 let row_dedup = RowDedupState::try_new(&schema);
502 let key_column_names: Vec<String> = key_column_indices
503 .iter()
504 .filter_map(|&i| schema.fields().get(i).map(|f| f.name().clone()))
505 .collect();
506 Self {
507 rule_name,
508 facts: Vec::new(),
509 delta: Vec::new(),
510 schema,
511 key_column_indices,
512 key_column_names,
513 all_column_indices: (0..num_cols).collect(),
514 facts_bytes: 0,
515 max_derived_bytes,
516 monotonic_agg,
517 row_dedup,
518 strict_probability_domain,
519 semiring_kind,
520 }
521 }
522
523 fn reconcile_schema(&mut self, actual_schema: &SchemaRef) {
530 if self.schema.fields() != actual_schema.fields() {
531 tracing::debug!(
532 rule = %self.rule_name,
533 "Reconciling fixpoint schema from physical plan output",
534 );
535 self.schema = Arc::clone(actual_schema);
536 self.row_dedup = RowDedupState::try_new(&self.schema);
537 let new_indices: Vec<usize> = self
541 .key_column_names
542 .iter()
543 .filter_map(|name| actual_schema.index_of(name).ok())
544 .collect();
545 if new_indices.len() == self.key_column_names.len() {
546 self.key_column_indices = new_indices;
547 }
548 let num_cols = actual_schema.fields().len();
550 self.all_column_indices = (0..num_cols).collect();
551 }
552 }
553
554 pub async fn merge_delta(
558 &mut self,
559 candidates: Vec<RecordBatch>,
560 task_ctx: Option<Arc<TaskContext>>,
561 ) -> DFResult<bool> {
562 if candidates.is_empty() || candidates.iter().all(|b| b.num_rows() == 0) {
563 self.delta.clear();
564 return Ok(false);
565 }
566
567 if let Some(first) = candidates.iter().find(|b| b.num_rows() > 0) {
571 self.reconcile_schema(&first.schema());
572 }
573
574 let candidates = round_float_columns(&candidates);
576
577 let delta = self.compute_delta(&candidates, task_ctx.as_ref()).await?;
579
580 if delta.is_empty() || delta.iter().all(|b| b.num_rows() == 0) {
581 self.delta.clear();
582 if let Some(ref mut agg) = self.monotonic_agg {
584 agg.snapshot();
585 }
586 return Ok(false);
587 }
588
589 let delta_bytes: usize = delta.iter().map(batch_byte_size).sum();
591 if self.facts_bytes + delta_bytes > self.max_derived_bytes {
592 return Err(datafusion::error::DataFusionError::Execution(
593 LocyRuntimeError::MemoryLimitExceeded {
594 rule: self.rule_name.clone(),
595 bytes: self.facts_bytes + delta_bytes,
596 limit: self.max_derived_bytes,
597 }
598 .to_string(),
599 ));
600 }
601
602 if let Some(ref mut agg) = self.monotonic_agg {
604 agg.snapshot();
605 agg.update(
606 &self.key_column_indices,
607 &delta,
608 self.strict_probability_domain,
609 self.semiring_kind,
610 )?;
611 }
612
613 self.facts_bytes += delta_bytes;
615 self.facts.extend(delta.iter().cloned());
616 self.delta = delta;
617
618 Ok(true)
619 }
620
621 async fn compute_delta(
628 &mut self,
629 candidates: &[RecordBatch],
630 task_ctx: Option<&Arc<TaskContext>>,
631 ) -> DFResult<Vec<RecordBatch>> {
632 let total_existing: usize = self.facts.iter().map(|b| b.num_rows()).sum();
633 if total_existing >= DEDUP_ANTI_JOIN_THRESHOLD
634 && let Some(ctx) = task_ctx
635 {
636 return arrow_left_anti_dedup(candidates.to_vec(), &self.facts, &self.schema, ctx)
637 .await;
638 }
639 if let Some(ref mut rd) = self.row_dedup {
640 rd.compute_delta(candidates, &self.schema)
641 } else {
642 self.compute_delta_legacy(candidates)
643 }
644 }
645
646 fn compute_delta_legacy(&self, candidates: &[RecordBatch]) -> DFResult<Vec<RecordBatch>> {
650 let mut existing: HashSet<Vec<ScalarKey>> = HashSet::new();
652 for batch in &self.facts {
653 for row_idx in 0..batch.num_rows() {
654 let key = extract_scalar_key(batch, &self.all_column_indices, row_idx);
655 existing.insert(key);
656 }
657 }
658
659 let mut delta_batches = Vec::new();
660 for batch in candidates {
661 if batch.num_rows() == 0 {
662 continue;
663 }
664 let mut keep = Vec::with_capacity(batch.num_rows());
666 for row_idx in 0..batch.num_rows() {
667 let key = extract_scalar_key(batch, &self.all_column_indices, row_idx);
668 keep.push(!existing.contains(&key));
669 }
670
671 for (row_idx, kept) in keep.iter_mut().enumerate() {
673 if *kept {
674 let key = extract_scalar_key(batch, &self.all_column_indices, row_idx);
675 if !existing.insert(key) {
676 *kept = false;
677 }
678 }
679 }
680
681 let keep_mask = arrow_array::BooleanArray::from(keep);
682 let new_rows = batch
683 .columns()
684 .iter()
685 .map(|col| {
686 arrow::compute::filter(col.as_ref(), &keep_mask).map_err(|e| {
687 datafusion::error::DataFusionError::ArrowError(Box::new(e), None)
688 })
689 })
690 .collect::<DFResult<Vec<_>>>()?;
691
692 if new_rows.first().is_some_and(|c| !c.is_empty()) {
693 let filtered =
694 RecordBatch::try_new(Arc::clone(&self.schema), new_rows).map_err(|e| {
695 datafusion::error::DataFusionError::ArrowError(Box::new(e), None)
696 })?;
697 delta_batches.push(filtered);
698 }
699 }
700
701 Ok(delta_batches)
702 }
703
704 pub fn is_converged(&self) -> bool {
706 let delta_empty = self.delta.is_empty() || self.delta.iter().all(|b| b.num_rows() == 0);
707 let agg_stable = self.monotonic_agg.as_ref().is_none_or(|a| a.is_stable());
708 delta_empty && agg_stable
709 }
710
711 pub fn all_facts(&self) -> &[RecordBatch] {
713 &self.facts
714 }
715
716 pub fn all_delta(&self) -> &[RecordBatch] {
718 &self.delta
719 }
720
721 pub fn into_facts(self) -> Vec<RecordBatch> {
723 self.facts
724 }
725
726 pub fn merge_best_by(
737 &mut self,
738 candidates: Vec<RecordBatch>,
739 sort_criteria: &[SortCriterion],
740 ) -> DFResult<bool> {
741 if candidates.is_empty() || candidates.iter().all(|b| b.num_rows() == 0) {
742 self.delta.clear();
743 return Ok(false);
744 }
745
746 if let Some(first) = candidates.iter().find(|b| b.num_rows() > 0) {
748 self.reconcile_schema(&first.schema());
749 }
750
751 let candidates = round_float_columns(&candidates);
753
754 let old_best: HashMap<Vec<ScalarKey>, Vec<ScalarKey>> =
756 self.build_key_criteria_map(sort_criteria);
757
758 let mut all_batches = self.facts.clone();
760 all_batches.extend(candidates);
761 let all_batches: Vec<_> = all_batches
762 .into_iter()
763 .filter(|b| b.num_rows() > 0)
764 .collect();
765 if all_batches.is_empty() {
766 self.delta.clear();
767 return Ok(false);
768 }
769
770 let combined = arrow::compute::concat_batches(&self.schema, &all_batches)
771 .map_err(|e| datafusion::error::DataFusionError::ArrowError(Box::new(e), None))?;
772
773 if combined.num_rows() == 0 {
774 self.delta.clear();
775 return Ok(false);
776 }
777
778 let mut sort_columns = Vec::new();
781 for &ki in &self.key_column_indices {
782 if ki >= combined.num_columns() {
783 continue;
784 }
785 sort_columns.push(arrow::compute::SortColumn {
786 values: Arc::clone(combined.column(ki)),
787 options: Some(arrow::compute::SortOptions {
788 descending: false,
789 nulls_first: false,
790 }),
791 });
792 }
793 for criterion in sort_criteria {
794 if criterion.col_index >= combined.num_columns() {
795 continue;
796 }
797 sort_columns.push(arrow::compute::SortColumn {
798 values: Arc::clone(combined.column(criterion.col_index)),
799 options: Some(arrow::compute::SortOptions {
800 descending: !criterion.ascending,
801 nulls_first: criterion.nulls_first,
802 }),
803 });
804 }
805
806 let sorted_indices =
807 arrow::compute::lexsort_to_indices(&sort_columns, None).map_err(arrow_err)?;
808 let sorted_columns: Vec<_> = combined
809 .columns()
810 .iter()
811 .map(|col| arrow::compute::take(col.as_ref(), &sorted_indices, None))
812 .collect::<Result<Vec<_>, _>>()
813 .map_err(|e| datafusion::error::DataFusionError::ArrowError(Box::new(e), None))?;
814 let sorted = RecordBatch::try_new(Arc::clone(&self.schema), sorted_columns)
815 .map_err(|e| datafusion::error::DataFusionError::ArrowError(Box::new(e), None))?;
816
817 let mut keep_indices: Vec<u32> = Vec::new();
819 let mut prev_key: Option<Vec<ScalarKey>> = None;
820 for row_idx in 0..sorted.num_rows() {
821 let key = extract_scalar_key(&sorted, &self.key_column_indices, row_idx);
822 let is_new_group = match &prev_key {
823 None => true,
824 Some(prev) => *prev != key,
825 };
826 if is_new_group {
827 keep_indices.push(row_idx as u32);
828 prev_key = Some(key);
829 }
830 }
831
832 let keep_array = arrow_array::UInt32Array::from(keep_indices);
833 let output_columns: Vec<_> = sorted
834 .columns()
835 .iter()
836 .map(|col| arrow::compute::take(col.as_ref(), &keep_array, None))
837 .collect::<Result<Vec<_>, _>>()
838 .map_err(|e| datafusion::error::DataFusionError::ArrowError(Box::new(e), None))?;
839 let pruned = RecordBatch::try_new(Arc::clone(&self.schema), output_columns)
840 .map_err(|e| datafusion::error::DataFusionError::ArrowError(Box::new(e), None))?;
841
842 let new_best: HashMap<Vec<ScalarKey>, Vec<ScalarKey>> = {
844 let mut map = HashMap::new();
845 for row_idx in 0..pruned.num_rows() {
846 let key = extract_scalar_key(&pruned, &self.key_column_indices, row_idx);
847 let criteria: Vec<ScalarKey> = sort_criteria
848 .iter()
849 .flat_map(|c| extract_scalar_key(&pruned, &[c.col_index], row_idx))
850 .collect();
851 map.insert(key, criteria);
852 }
853 map
854 };
855 let changed = old_best != new_best;
856
857 tracing::debug!(
858 rule = %self.rule_name,
859 old_keys = old_best.len(),
860 new_keys = new_best.len(),
861 changed = changed,
862 "BEST BY merge"
863 );
864
865 self.facts_bytes = batch_byte_size(&pruned);
867 self.facts = vec![pruned];
868 if changed {
869 self.delta = self.facts.clone();
872 } else {
873 self.delta.clear();
874 }
875
876 self.row_dedup = RowDedupState::try_new(&self.schema);
878 if let Some(ref mut rd) = self.row_dedup {
879 rd.ingest_existing(&self.facts, &self.schema);
880 }
881
882 Ok(changed)
883 }
884
885 fn build_key_criteria_map(
887 &self,
888 sort_criteria: &[SortCriterion],
889 ) -> HashMap<Vec<ScalarKey>, Vec<ScalarKey>> {
890 let mut map = HashMap::new();
891 for batch in &self.facts {
892 for row_idx in 0..batch.num_rows() {
893 let key = extract_scalar_key(batch, &self.key_column_indices, row_idx);
894 let criteria: Vec<ScalarKey> = sort_criteria
895 .iter()
896 .flat_map(|c| extract_scalar_key(batch, &[c.col_index], row_idx))
897 .collect();
898 map.insert(key, criteria);
899 }
900 }
901 map
902 }
903}
904
905fn batch_byte_size(batch: &RecordBatch) -> usize {
907 batch
908 .columns()
909 .iter()
910 .map(|col| col.get_buffer_memory_size())
911 .sum()
912}
913
914fn round_float_columns(batches: &[RecordBatch]) -> Vec<RecordBatch> {
920 batches
921 .iter()
922 .map(|batch| {
923 let schema = batch.schema();
924 let has_float = schema
925 .fields()
926 .iter()
927 .any(|f| *f.data_type() == arrow_schema::DataType::Float64);
928 if !has_float {
929 return batch.clone();
930 }
931
932 let columns: Vec<arrow_array::ArrayRef> = batch
933 .columns()
934 .iter()
935 .enumerate()
936 .map(|(i, col)| {
937 if *schema.field(i).data_type() == arrow_schema::DataType::Float64 {
938 let arr = col
939 .as_any()
940 .downcast_ref::<arrow_array::Float64Array>()
941 .unwrap();
942 let rounded: arrow_array::Float64Array = arr
943 .iter()
944 .map(|v| v.map(|f| (f * 1e12).round() / 1e12))
945 .collect();
946 Arc::new(rounded) as arrow_array::ArrayRef
947 } else {
948 Arc::clone(col)
949 }
950 })
951 .collect();
952
953 RecordBatch::try_new(schema, columns).unwrap_or_else(|_| batch.clone())
954 })
955 .collect()
956}
957
/// Number of already-accumulated fact rows at which `FixpointState` switches
/// from in-memory dedup to the DataFusion anti-join path. The switch also
/// requires a task context.
const DEDUP_ANTI_JOIN_THRESHOLD: usize = 300;
968
969fn dedup_batches_all_columns(
984 batches: Vec<RecordBatch>,
985 schema: &SchemaRef,
986) -> DFResult<Vec<RecordBatch>> {
987 let fields: Vec<SortField> = schema
988 .fields()
989 .iter()
990 .map(|f| SortField::new(f.data_type().clone()))
991 .collect();
992 let Ok(converter) = RowConverter::new(fields) else {
996 return Ok(batches);
997 };
998 let mut seen: HashSet<Box<[u8]>> = HashSet::new();
999 let mut out = Vec::with_capacity(batches.len());
1000 for batch in batches {
1001 if batch.num_rows() == 0 {
1002 continue;
1003 }
1004 let rows = converter
1005 .convert_columns(batch.columns())
1006 .map_err(arrow_err)?;
1007 let mut keep = Vec::with_capacity(batch.num_rows());
1008 for row_idx in 0..batch.num_rows() {
1009 let row_bytes: Box<[u8]> = rows.row(row_idx).data().into();
1010 keep.push(seen.insert(row_bytes));
1011 }
1012 let keep_mask = arrow_array::BooleanArray::from(keep);
1013 let cols = batch
1014 .columns()
1015 .iter()
1016 .map(|c| arrow::compute::filter(c.as_ref(), &keep_mask).map_err(arrow_err))
1017 .collect::<DFResult<Vec<_>>>()?;
1018 if cols.first().is_some_and(|c| !c.is_empty()) {
1019 out.push(RecordBatch::try_new(Arc::clone(schema), cols).map_err(arrow_err)?);
1020 }
1021 }
1022 Ok(out)
1023}
1024
1025async fn arrow_left_anti_dedup(
1026 candidates: Vec<RecordBatch>,
1027 existing: &[RecordBatch],
1028 schema: &SchemaRef,
1029 task_ctx: &Arc<TaskContext>,
1030) -> DFResult<Vec<RecordBatch>> {
1031 if existing.is_empty() || existing.iter().all(|b| b.num_rows() == 0) {
1032 return dedup_batches_all_columns(candidates, schema);
1035 }
1036
1037 let left: Arc<dyn ExecutionPlan> = Arc::new(InMemoryExec::new(candidates, Arc::clone(schema)));
1038 let right: Arc<dyn ExecutionPlan> =
1039 Arc::new(InMemoryExec::new(existing.to_vec(), Arc::clone(schema)));
1040
1041 let on: Vec<(
1042 Arc<dyn datafusion::physical_plan::PhysicalExpr>,
1043 Arc<dyn datafusion::physical_plan::PhysicalExpr>,
1044 )> = schema
1045 .fields()
1046 .iter()
1047 .enumerate()
1048 .map(|(i, field)| {
1049 let l: Arc<dyn datafusion::physical_plan::PhysicalExpr> = Arc::new(
1050 datafusion::physical_plan::expressions::Column::new(field.name(), i),
1051 );
1052 let r: Arc<dyn datafusion::physical_plan::PhysicalExpr> = Arc::new(
1053 datafusion::physical_plan::expressions::Column::new(field.name(), i),
1054 );
1055 (l, r)
1056 })
1057 .collect();
1058
1059 if on.is_empty() {
1060 return Ok(vec![]);
1061 }
1062
1063 let join = HashJoinExec::try_new(
1064 left,
1065 right,
1066 on,
1067 None,
1068 &JoinType::LeftAnti,
1069 None,
1070 PartitionMode::CollectLeft,
1071 datafusion::common::NullEquality::NullEqualsNull,
1072 false,
1077 )?;
1078
1079 let join_arc: Arc<dyn ExecutionPlan> = Arc::new(join);
1080 let anti = collect_all_partitions(&join_arc, task_ctx.clone()).await?;
1083 dedup_batches_all_columns(anti, schema)
1084}
1085
/// A reference from a clause body to another rule's derived relation.
#[derive(Debug, Clone)]
pub struct IsRefBinding {
    /// Registry `scan_index` of the derived scan holding the target rule's facts.
    pub derived_scan_index: usize,
    /// Name of the referenced rule. Also used to name the complement column
    /// `__prob_complement_<rule_name>`.
    pub rule_name: String,
    /// Whether the reference targets the rule being defined (presumably a
    /// recursive reference — TODO confirm against the planner).
    pub is_self_ref: bool,
    /// Whether the reference is negated. A negated reference with non-empty
    /// `anti_join_cols` is applied as an anti-join or probability complement
    /// against the target's facts.
    pub negated: bool,
    /// Column pairs used to match body rows against the target's facts for
    /// negation. NOTE(review): the (body, target) orientation is assumed.
    pub anti_join_cols: Vec<(String, String)>,
    /// Whether the target rule carries a probability column.
    pub target_has_prob: bool,
    /// Name of the target rule's probability column, if known.
    pub target_prob_col: Option<String>,
    /// Column pairs presumably used to link derived rows to their support for
    /// provenance. Not consumed in this part of the file; TODO confirm.
    pub provenance_join_cols: Vec<(String, String)>,
}
1117
/// One clause of a recursive rule, as prepared for the fixpoint loop.
#[derive(Debug)]
pub struct FixpointClausePlan {
    /// Logical plan of the clause body, re-executed on every iteration.
    pub body_logical: LogicalPlan,
    /// References to derived relations appearing in the body.
    pub is_ref_bindings: Vec<IsRefBinding>,
    /// Clause priority, if any (semantics not visible here — TODO confirm).
    pub priority: Option<i64>,
    /// ALONG binding names (presumably path-carried variables — TODO confirm).
    pub along_bindings: Vec<String>,
    /// Model/classifier invocations made by this clause (presumably feeding
    /// neural provenance — TODO confirm).
    pub model_invocations: Vec<ModelInvocation>,
}
1134
/// A rule evaluated by the fixpoint loop, with its clauses and post-processing
/// options.
#[derive(Debug)]
pub struct FixpointRulePlan {
    /// Rule name.
    pub name: String,
    /// Clauses whose outputs are unioned into the rule's candidates each iteration.
    pub clauses: Vec<FixpointClausePlan>,
    /// Expected output schema of the rule's facts.
    pub yield_schema: SchemaRef,
    /// Group-key columns of `yield_schema`.
    pub key_column_indices: Vec<usize>,
    /// Rule priority, if any.
    pub priority: Option<i64>,
    /// Whether the rule has FOLD aggregates.
    pub has_fold: bool,
    /// FOLD bindings; non-empty bindings enable monotonic aggregate tracking.
    pub fold_bindings: Vec<FoldBinding>,
    /// HAVING predicates (applied outside the visible loop — TODO confirm).
    pub having: Vec<Expr>,
    /// Whether the rule uses BEST BY.
    pub has_best_by: bool,
    /// BEST BY sort criteria; BEST BY merging is used only when this is non-empty.
    pub best_by_criteria: Vec<SortCriterion>,
    /// Whether the rule uses PRIORITY.
    pub has_priority: bool,
    /// Determinism flag (semantics not visible here — TODO confirm).
    pub deterministic: bool,
    /// Name of the rule's probability column, if it carries one.
    pub prob_column_name: Option<String>,
    /// Whether the rule is non-linear (presumably references itself more than
    /// once in a body — TODO confirm).
    pub non_linear: bool,
}
1174
1175#[expect(clippy::too_many_arguments, reason = "Fixpoint loop needs all context")]
1184async fn run_fixpoint_loop(
1185 rules: Vec<FixpointRulePlan>,
1186 max_iterations: usize,
1187 timeout: Duration,
1188 graph_ctx: Arc<GraphExecutionContext>,
1189 session_ctx: Arc<RwLock<datafusion::prelude::SessionContext>>,
1190 storage: Arc<StorageManager>,
1191 schema_info: Arc<UniSchema>,
1192 params: HashMap<String, Value>,
1193 registry: Arc<DerivedScanRegistry>,
1194 output_schema: SchemaRef,
1195 max_derived_bytes: usize,
1196 derivation_tracker: Option<Arc<ProvenanceStore>>,
1197 iteration_counts: Arc<StdRwLock<HashMap<String, usize>>>,
1198 strict_probability_domain: bool,
1199 probability_epsilon: f64,
1200 exact_probability: bool,
1201 max_bdd_variables: usize,
1202 warnings_slot: Arc<StdRwLock<Vec<RuntimeWarning>>>,
1203 approximate_slot: Arc<StdRwLock<HashMap<String, Vec<String>>>>,
1204 top_k_proofs: usize,
1205 timeout_flag: Arc<std::sync::atomic::AtomicU8>,
1206 semiring_kind: SemiringKind,
1207 classifier_registry: Arc<ClassifierRegistry>,
1208 classifier_cache: Option<Arc<ModelInvocationCache>>,
1209 classifier_provenance_store: Option<Arc<uni_locy::NeuralProvenanceStore>>,
1210 profile_collector: Option<Arc<LocyProfileCollector>>,
1211) -> DFResult<Vec<RecordBatch>> {
1212 let start = Instant::now();
1213 let task_ctx = session_ctx.read().task_ctx();
1214
1215 if semiring_kind == SemiringKind::MaxMinProb {
1220 let mut warnings = warnings_slot.write().unwrap_or_else(|e| e.into_inner());
1221 let mut already_warned: HashSet<String> = warnings
1222 .iter()
1223 .filter(|w| w.code == RuntimeWarningCode::FuzzyNotProbabilistic)
1224 .map(|w| w.rule_name.clone())
1225 .collect();
1226 for rule in &rules {
1227 if rule.prob_column_name.is_some() && !already_warned.contains(&rule.name) {
1228 warnings.push(RuntimeWarning {
1229 code: RuntimeWarningCode::FuzzyNotProbabilistic,
1230 message: format!(
1231 "rule '{}' carries a PROB column but is being evaluated under \
1232 the MaxMinProb (fuzzy / Viterbi) semiring; outputs are fuzzy \
1233 truth values, not probabilities",
1234 rule.name
1235 ),
1236 rule_name: rule.name.clone(),
1237 variable_count: None,
1238 key_group: None,
1239 });
1240 already_warned.insert(rule.name.clone());
1241 }
1242 }
1243 }
1244
1245 let mut states: Vec<FixpointState> = rules
1247 .iter()
1248 .map(|rule| {
1249 let monotonic_agg = if !rule.fold_bindings.is_empty() {
1250 let bindings: Vec<MonotonicFoldBinding> = rule
1251 .fold_bindings
1252 .iter()
1253 .map(|fb| MonotonicFoldBinding {
1254 fold_name: fb.output_name.clone(),
1255 aggregate: std::sync::Arc::clone(&fb.aggregate),
1256 input_col_index: fb.input_col_index,
1257 input_col_name: fb.input_col_name.clone(),
1258 })
1259 .collect();
1260 Some(MonotonicAggState::new(bindings))
1261 } else {
1262 None
1263 };
1264 FixpointState::new_with_semiring(
1265 rule.name.clone(),
1266 Arc::clone(&rule.yield_schema),
1267 rule.key_column_indices.clone(),
1268 max_derived_bytes,
1269 monotonic_agg,
1270 strict_probability_domain,
1271 semiring_kind,
1272 )
1273 })
1274 .collect();
1275
1276 let mut converged = false;
1278 let mut total_iters = 0usize;
1279 for iteration in 0..max_iterations {
1280 total_iters = iteration + 1;
1281 tracing::debug!("fixpoint iteration {}", iteration);
1282 let mut any_changed = false;
1283
1284 for rule_idx in 0..rules.len() {
1285 let rule = &rules[rule_idx];
1286
1287 update_derived_scan_handles(®istry, &states, rule_idx, &rules);
1289
1290 let mut all_candidates = Vec::new();
1292 let mut clause_candidates: Vec<Vec<RecordBatch>> = Vec::new();
1293 let rule_start = Instant::now();
1296 let mut iter_ops: Vec<OperatorStats> = Vec::new();
1297 for clause in &rule.clauses {
1298 let mut batches = if profile_collector.is_some() {
1304 let (b, ops) = execute_subplan_collecting(
1305 &clause.body_logical,
1306 ¶ms,
1307 &HashMap::new(),
1308 &graph_ctx,
1309 &session_ctx,
1310 &storage,
1311 &schema_info,
1312 None, )
1314 .await?;
1315 iter_ops.extend(ops);
1316 b
1317 } else {
1318 execute_subplan(
1319 &clause.body_logical,
1320 ¶ms,
1321 &HashMap::new(),
1322 &graph_ctx,
1323 &session_ctx,
1324 &storage,
1325 &schema_info,
1326 None, )
1328 .await?
1329 };
1330 for binding in &clause.is_ref_bindings {
1332 if binding.negated
1333 && !binding.anti_join_cols.is_empty()
1334 && let Some(entry) = registry.get(binding.derived_scan_index)
1335 {
1336 let neg_facts = entry.data.read().clone();
1337 if !neg_facts.is_empty() {
1338 if binding.target_has_prob && rule.prob_column_name.is_some() {
1339 let complement_col =
1341 format!("__prob_complement_{}", binding.rule_name);
1342 if let Some(prob_col) = &binding.target_prob_col {
1343 batches = apply_prob_complement_composite(
1344 batches,
1345 &neg_facts,
1346 &binding.anti_join_cols,
1347 prob_col,
1348 &complement_col,
1349 )?;
1350 } else {
1351 batches = apply_anti_join_composite(
1353 batches,
1354 &neg_facts,
1355 &binding.anti_join_cols,
1356 )?;
1357 }
1358 } else {
1359 batches = apply_anti_join_composite(
1361 batches,
1362 &neg_facts,
1363 &binding.anti_join_cols,
1364 )?;
1365 }
1366 }
1367 }
1368 }
1369 let complement_cols: Vec<String> = if !batches.is_empty() {
1371 batches[0]
1372 .schema()
1373 .fields()
1374 .iter()
1375 .filter(|f| f.name().starts_with("__prob_complement_"))
1376 .map(|f| f.name().clone())
1377 .collect()
1378 } else {
1379 vec![]
1380 };
1381 if !complement_cols.is_empty() {
1382 batches = multiply_prob_factors(
1383 batches,
1384 rule.prob_column_name.as_deref(),
1385 &complement_cols,
1386 )?;
1387 }
1388
1389 clause_candidates.push(batches.clone());
1390 all_candidates.extend(batches);
1391 }
1392
1393 let changed = if rule.has_best_by && !rule.best_by_criteria.is_empty() {
1397 states[rule_idx].merge_best_by(all_candidates, &rule.best_by_criteria)?
1398 } else {
1399 states[rule_idx]
1400 .merge_delta(all_candidates, Some(Arc::clone(&task_ctx)))
1401 .await?
1402 };
1403 if changed {
1404 any_changed = true;
1405 if let Some(ref tracker) = derivation_tracker {
1407 record_provenance(
1408 ProvenanceCtx {
1409 tracker,
1410 registry: ®istry,
1411 warnings_slot: &warnings_slot,
1412 },
1413 rule,
1414 &states[rule_idx],
1415 &clause_candidates,
1416 iteration,
1417 top_k_proofs,
1418 ClassifierRefs {
1419 registry: &classifier_registry,
1420 cache: classifier_cache.as_ref(),
1421 provenance_store: classifier_provenance_store.as_ref(),
1422 },
1423 )
1424 .await;
1425 }
1426 }
1427
1428 if let Some(ref collector) = profile_collector {
1431 let delta_facts: usize = states[rule_idx]
1432 .all_delta()
1433 .iter()
1434 .map(|b| b.num_rows())
1435 .sum();
1436 collector.record(
1437 &rule.name,
1438 iteration,
1439 delta_facts,
1440 rule_start.elapsed().as_secs_f64() * 1000.0,
1441 iter_ops,
1442 );
1443 }
1444 }
1445
1446 if !any_changed && states.iter().all(|s| s.is_converged()) {
1448 tracing::debug!("fixpoint converged after {} iterations", iteration + 1);
1449 converged = true;
1450 break;
1451 }
1452
1453 if start.elapsed() > timeout {
1455 tracing::warn!(
1456 "fixpoint timeout after {} iterations; returning partial results",
1457 iteration + 1,
1458 );
1459 interruption::set(&timeout_flag, interruption::TIMEOUT);
1460 break;
1461 }
1462 }
1463
1464 if let Ok(mut counts) = iteration_counts.write() {
1466 for rule in &rules {
1467 counts.insert(rule.name.clone(), total_iters);
1468 }
1469 }
1470
1471 if let Some(ref collector) = profile_collector {
1473 for (idx, rule) in rules.iter().enumerate() {
1474 let facts: usize = states[idx].all_facts().iter().map(|b| b.num_rows()).sum();
1475 collector.set_final_facts(&rule.name, facts);
1476 }
1477 }
1478
1479 if !converged && interruption::reason(&timeout_flag).is_none() {
1484 tracing::warn!(
1485 "fixpoint did not converge after {max_iterations} iterations; returning partial results",
1486 );
1487 interruption::set(&timeout_flag, interruption::ITERATION_LIMIT);
1488 }
1489
1490 let task_ctx = session_ctx.read().task_ctx();
1492 let mut all_output = Vec::new();
1493
1494 for (rule_idx, state) in states.into_iter().enumerate() {
1495 let rule = &rules[rule_idx];
1496 let mut facts = state.into_facts();
1497 if facts.is_empty() {
1498 continue;
1499 }
1500
1501 let shared_info = if semiring_kind == SemiringKind::MaxMinProb {
1519 None
1520 } else if let Some(ref tracker) = derivation_tracker {
1521 detect_shared_lineage(rule, &facts, tracker, &warnings_slot, semiring_kind)
1522 } else {
1523 None
1524 };
1525
1526 if exact_probability
1528 && let Some(ref info) = shared_info
1529 && let Some(ref tracker) = derivation_tracker
1530 {
1531 facts = apply_exact_wmc(
1532 facts,
1533 rule,
1534 info,
1535 tracker,
1536 max_bdd_variables,
1537 &warnings_slot,
1538 &approximate_slot,
1539 )?;
1540 }
1541
1542 let processed = apply_post_fixpoint_chain(
1543 facts,
1544 rule,
1545 &task_ctx,
1546 strict_probability_domain,
1547 probability_epsilon,
1548 semiring_kind,
1549 derivation_tracker.as_ref().map(Arc::clone),
1550 top_k_proofs,
1551 Some(Arc::clone(®istry)),
1552 )
1553 .await?;
1554 all_output.extend(processed);
1555 }
1556
1557 if all_output.is_empty() {
1559 all_output.push(RecordBatch::new_empty(output_schema));
1560 }
1561
1562 Ok(all_output)
1563}
1564
1565pub(crate) struct ClassifierRefs<'a> {
1577 pub registry: &'a Arc<ClassifierRegistry>,
1578 pub cache: Option<&'a Arc<uni_locy::ModelInvocationCache>>,
1579 pub provenance_store: Option<&'a Arc<uni_locy::NeuralProvenanceStore>>,
1586}
1587
1588pub(crate) struct ProvenanceCtx<'a> {
1594 pub tracker: &'a Arc<ProvenanceStore>,
1595 pub registry: &'a Arc<DerivedScanRegistry>,
1596 pub warnings_slot: &'a Arc<StdRwLock<Vec<RuntimeWarning>>>,
1597}
1598
1599async fn record_provenance(
1600 prov: ProvenanceCtx<'_>,
1601 rule: &FixpointRulePlan,
1602 state: &FixpointState,
1603 clause_candidates: &[Vec<RecordBatch>],
1604 iteration: usize,
1605 top_k_proofs: usize,
1606 classifiers: ClassifierRefs<'_>,
1607) {
1608 let tracker = prov.tracker;
1609 let registry = prov.registry;
1610 let warnings_slot = prov.warnings_slot;
1611 let classifier_registry = classifiers.registry;
1612 let classifier_cache = classifiers.cache;
1613 let all_indices: Vec<usize> = (0..rule.yield_schema.fields().len()).collect();
1614
1615 let base_probs = if top_k_proofs > 0 {
1617 tracker.base_fact_probs()
1618 } else {
1619 HashMap::new()
1620 };
1621
1622 let mut topk_acc = TopKProofAccumulator::new();
1623
1624 for delta_batch in state.all_delta() {
1625 for row_idx in 0..delta_batch.num_rows() {
1626 let row_hash = format!(
1627 "{:?}",
1628 extract_scalar_key(delta_batch, &all_indices, row_idx)
1629 )
1630 .into_bytes();
1631 let fact_row = batch_row_to_value_map(delta_batch, row_idx);
1632 let clause_index =
1633 find_clause_for_row(delta_batch, row_idx, &all_indices, clause_candidates);
1634
1635 let support = collect_is_ref_inputs(rule, clause_index, delta_batch, row_idx, registry);
1636
1637 let proof_probability = if top_k_proofs > 0 {
1638 compute_proof_probability(&support, &base_probs)
1639 } else {
1640 None
1641 };
1642
1643 let entry = ProvenanceAnnotation {
1644 rule_name: rule.name.clone(),
1645 clause_index,
1646 support,
1647 along_values: {
1648 let along_names: Vec<String> = rule
1649 .clauses
1650 .get(clause_index)
1651 .map(|c| c.along_bindings.clone())
1652 .unwrap_or_default();
1653 along_names
1654 .iter()
1655 .filter_map(|name| fact_row.get(name).map(|v| (name.clone(), v.clone())))
1656 .collect()
1657 },
1658 iteration,
1659 fact_row: fact_row.clone(),
1660 proof_probability,
1661 neural_calls: collect_neural_calls_for_row(
1662 rule,
1663 clause_index,
1664 &fact_row,
1665 classifier_registry,
1666 classifier_cache,
1667 classifiers.provenance_store,
1668 )
1669 .await,
1670 };
1671 if top_k_proofs > 0 {
1672 topk_acc.accumulate(&entry, &row_hash);
1673 tracker.record_top_k(row_hash, entry, top_k_proofs);
1674 } else {
1675 tracker.record(row_hash, entry);
1676 }
1677 }
1678 }
1679
1680 topk_acc.emit_warning_if_any(rule, top_k_proofs, warnings_slot);
1681}
1682
1683struct TopKProofAccumulator {
1690 per_fact: HashMap<Vec<u8>, Vec<uni_locy::Proof>>,
1691 base_rv_interner: HashMap<Vec<u8>, uni_locy::BaseRv>,
1692 next_rv: u32,
1693}
1694
1695impl TopKProofAccumulator {
1696 fn new() -> Self {
1697 Self {
1698 per_fact: HashMap::new(),
1699 base_rv_interner: HashMap::new(),
1700 next_rv: 0,
1701 }
1702 }
1703
1704 fn accumulate(&mut self, entry: &ProvenanceAnnotation, row_hash: &[u8]) {
1705 let mut base_rvs = uni_locy::BaseRvSet::empty();
1706 for term in &entry.support {
1707 let rv = *self
1708 .base_rv_interner
1709 .entry(term.base_fact_id.clone())
1710 .or_insert_with(|| {
1711 let r = uni_locy::BaseRv(self.next_rv);
1712 self.next_rv += 1;
1713 r
1714 });
1715 base_rvs.insert(rv);
1716 }
1717 self.per_fact
1718 .entry(row_hash.to_vec())
1719 .or_default()
1720 .push(uni_locy::Proof {
1721 weight: entry.proof_probability.unwrap_or(0.0),
1722 base_rvs,
1723 neural_calls: Vec::new(),
1724 });
1725 }
1726
1727 fn emit_warning_if_any(
1728 &self,
1729 rule: &FixpointRulePlan,
1730 top_k_proofs: usize,
1731 warnings_slot: &Arc<StdRwLock<Vec<RuntimeWarning>>>,
1732 ) {
1733 if top_k_proofs == 0 || self.per_fact.is_empty() {
1734 return;
1735 }
1736 let crossed_facts = self
1737 .per_fact
1738 .values()
1739 .filter(|proofs| {
1740 let (_kept, notice) =
1741 uni_locy::merge_top_k_runtime(Vec::new(), (*proofs).clone(), top_k_proofs);
1742 notice == uni_locy::PruneNotice::CrossedDependency
1743 })
1744 .count();
1745 if crossed_facts == 0 {
1746 return;
1747 }
1748 let Ok(mut w) = warnings_slot.write() else {
1749 return;
1750 };
1751 let already = w.iter().any(|rw| {
1752 matches!(
1753 rw.code,
1754 uni_locy::types::RuntimeWarningCode::TopKPruningCrossedDependency
1755 ) && rw.rule_name == rule.name
1756 });
1757 if already {
1758 return;
1759 }
1760 w.push(RuntimeWarning {
1761 code: uni_locy::types::RuntimeWarningCode::TopKPruningCrossedDependency,
1762 rule_name: rule.name.clone(),
1763 message: format!(
1764 "rule '{}': top-K proof pruning (k={}) discarded {} fact(s) \
1765 whose dependencies overlap retained proofs. The retained \
1766 top-{} under-counts the true joint probability for those \
1767 facts (Scallop, Huang et al. 2021). Increase k to recover.",
1768 rule.name, top_k_proofs, crossed_facts, top_k_proofs
1769 ),
1770 variable_count: None,
1771 key_group: None,
1772 });
1773 }
1774}
1775
1776async fn collect_neural_calls_for_row(
1798 rule: &FixpointRulePlan,
1799 clause_index: usize,
1800 fact_row: &uni_locy::FactRow,
1801 classifier_registry: &Arc<ClassifierRegistry>,
1802 classifier_cache: Option<&Arc<uni_locy::ModelInvocationCache>>,
1803 provenance_store: Option<&Arc<uni_locy::NeuralProvenanceStore>>,
1804) -> Vec<uni_locy::NeuralProvenance> {
1805 let Some(clause) = rule.clauses.get(clause_index) else {
1806 return Vec::new();
1807 };
1808 if clause.model_invocations.is_empty() {
1809 return Vec::new();
1810 }
1811 let mut out = Vec::with_capacity(clause.model_invocations.len());
1812 for invocation in &clause.model_invocations {
1813 let mut features = std::collections::HashMap::new();
1824 for (binding_name, feat_expr) in invocation
1825 .feature_names
1826 .iter()
1827 .zip(invocation.feature_exprs.iter())
1828 {
1829 features.insert(
1830 binding_name.clone(),
1831 eval_feature_expr_against_fact_row(feat_expr, fact_row),
1832 );
1833 }
1834 let input = uni_locy::ClassifyInput { features };
1835 let input_hash = input.stable_hash();
1836
1837 if let Some(store) = provenance_store
1844 && let Some(record) = store.get(&invocation.model_name, input_hash)
1845 {
1846 out.push(uni_locy::NeuralProvenance {
1847 model_name: invocation.model_name.clone(),
1848 raw_probability: record.raw_probability,
1849 calibrated_probability: record.calibrated_probability,
1850 confidence_band: record.confidence_band,
1851 });
1852 continue;
1853 }
1854
1855 let Some(classifier) = classifier_registry.get(&invocation.model_name) else {
1860 continue;
1861 };
1862 let raw = if let Some(v) =
1863 classifier_cache.and_then(|c| c.get(&invocation.model_name, input_hash))
1864 {
1865 v
1866 } else {
1867 match classifier.classify(std::slice::from_ref(&input)).await {
1868 Ok(probs) => {
1869 let v = probs.first().copied().unwrap_or(0.0);
1870 if let Some(c) = classifier_cache {
1871 c.insert(&invocation.model_name, input_hash, v);
1872 }
1873 v
1874 }
1875 Err(_) => continue,
1876 }
1877 };
1878 let calibrator = classifier.get_calibrator();
1879 let calibrated_probability = calibrator.as_ref().map(|_| raw);
1880 let confidence_band = calibrator.as_ref().and_then(|c| c.confidence_band(raw));
1881 out.push(uni_locy::NeuralProvenance {
1882 model_name: invocation.model_name.clone(),
1883 raw_probability: raw,
1884 calibrated_probability,
1885 confidence_band,
1886 });
1887 }
1888 out
1889}
1890
1891fn eval_feature_expr_against_fact_row(
1899 expr: &uni_cypher::ast::Expr,
1900 fact_row: &uni_locy::FactRow,
1901) -> uni_locy::FeatureValue {
1902 use uni_cypher::ast::Expr;
1903 use uni_locy::FeatureValue;
1904 let value_to_feature = |v: Option<&uni_common::Value>| -> FeatureValue {
1905 match v {
1906 Some(uni_common::Value::Float(f)) => FeatureValue::Float(*f),
1907 Some(uni_common::Value::Int(i)) => FeatureValue::Int(*i),
1908 Some(uni_common::Value::Bool(b)) => FeatureValue::Bool(*b),
1909 Some(uni_common::Value::String(s)) => FeatureValue::String(s.clone()),
1910 Some(uni_common::Value::Node(n)) => {
1911 FeatureValue::Int(n.vid.as_u64() as i64)
1913 }
1914 _ => FeatureValue::Null,
1915 }
1916 };
1917 let resolve_value = |sub: &Expr| -> uni_common::Value {
1921 match sub {
1922 Expr::Variable(name) => fact_row
1923 .get(name)
1924 .cloned()
1925 .unwrap_or(uni_common::Value::Null),
1926 Expr::Property(boxed, prop) if matches!(boxed.as_ref(), Expr::Variable(_)) => {
1927 let Expr::Variable(v) = boxed.as_ref() else {
1928 unreachable!()
1929 };
1930 let key = format!("{}.{}", v, prop);
1931 if let Some(val) = fact_row.get(&key) {
1932 return val.clone();
1933 }
1934 if let Some(uni_common::Value::Node(n)) = fact_row.get(v) {
1935 return n
1936 .properties
1937 .get(prop)
1938 .cloned()
1939 .unwrap_or(uni_common::Value::Null);
1940 }
1941 uni_common::Value::Null
1942 }
1943 Expr::Literal(lit) => lit.to_value(),
1944 Expr::List(items) => {
1945 let mut out = Vec::with_capacity(items.len());
1946 for it in items {
1947 out.push(match it {
1948 Expr::Literal(lit) => lit.to_value(),
1949 _ => uni_common::Value::Null,
1950 });
1951 }
1952 uni_common::Value::List(out)
1953 }
1954 _ => uni_common::Value::Null,
1955 }
1956 };
1957
1958 match expr {
1959 Expr::Variable(name) => value_to_feature(fact_row.get(name)),
1960 Expr::Property(boxed, prop) => {
1961 if let Expr::Variable(v) = boxed.as_ref() {
1962 let key = format!("{}.{}", v, prop);
1964 if let Some(val) = fact_row.get(&key) {
1965 return value_to_feature(Some(val));
1966 }
1967 let hidden_key = format!("__feat_{}_{}", v, prop);
1977 if let Some(val) = fact_row.get(&hidden_key) {
1978 return value_to_feature(Some(val));
1979 }
1980 if let Some(uni_common::Value::Node(n)) = fact_row.get(v) {
1984 return value_to_feature(n.properties.get(prop));
1985 }
1986 }
1987 FeatureValue::Null
1988 }
1989 Expr::FunctionCall { name, args, .. } if name == "similar_to" && args.len() == 2 => {
1990 let lv = resolve_value(&args[0]);
1991 let rv = resolve_value(&args[1]);
1992 match crate::query::similar_to::eval_similar_to_pure(&lv, &rv) {
1993 Ok(uni_common::Value::Float(f)) => FeatureValue::Float(f),
1994 _ => FeatureValue::Null,
1995 }
1996 }
1997 Expr::FunctionCall { name, .. }
2012 if matches!(
2013 name.as_str(),
2014 "degree_centrality"
2015 | "pagerank_score"
2016 | "closeness_centrality"
2017 | "betweenness_centrality"
2018 | "eigenvector_centrality"
2019 | "harmonic_centrality"
2020 | "katz_centrality"
2021 | "avg_neighbor"
2022 | "max_neighbor"
2023 | "sum_neighbor"
2024 ) =>
2025 {
2026 FeatureValue::Null
2027 }
2028 _ => FeatureValue::Null,
2029 }
2030}
2031
2032fn collect_is_ref_inputs(
2033 rule: &FixpointRulePlan,
2034 clause_index: usize,
2035 delta_batch: &RecordBatch,
2036 row_idx: usize,
2037 registry: &Arc<DerivedScanRegistry>,
2038) -> Vec<ProofTerm> {
2039 let clause = match rule.clauses.get(clause_index) {
2040 Some(c) => c,
2041 None => return vec![],
2042 };
2043
2044 let mut inputs = Vec::new();
2045 let delta_schema = delta_batch.schema();
2046
2047 for binding in &clause.is_ref_bindings {
2048 if binding.negated {
2049 continue;
2050 }
2051 if binding.provenance_join_cols.is_empty() {
2052 continue;
2053 }
2054
2055 let body_values: Vec<(String, ScalarKey)> = binding
2057 .provenance_join_cols
2058 .iter()
2059 .filter_map(|(body_col, _derived_col)| {
2060 let col_idx = delta_schema
2061 .fields()
2062 .iter()
2063 .position(|f| f.name() == body_col)?;
2064 let key = extract_scalar_key(delta_batch, &[col_idx], row_idx);
2065 Some((body_col.clone(), key.into_iter().next()?))
2066 })
2067 .collect();
2068
2069 if body_values.len() != binding.provenance_join_cols.len() {
2070 continue;
2071 }
2072
2073 let entry = match registry.get(binding.derived_scan_index) {
2075 Some(e) => e,
2076 None => continue,
2077 };
2078 let source_batches = entry.data.read();
2079 let source_schema = &entry.schema;
2080
2081 for src_batch in source_batches.iter() {
2083 let all_src_indices: Vec<usize> = (0..src_batch.num_columns()).collect();
2084 for src_row in 0..src_batch.num_rows() {
2085 let matches = binding.provenance_join_cols.iter().enumerate().all(
2086 |(i, (_body_col, derived_col))| {
2087 let src_col_idx = source_schema
2088 .fields()
2089 .iter()
2090 .position(|f| f.name() == derived_col);
2091 match src_col_idx {
2092 Some(idx) => {
2093 let src_key = extract_scalar_key(src_batch, &[idx], src_row);
2094 src_key.first() == Some(&body_values[i].1)
2095 }
2096 None => false,
2097 }
2098 },
2099 );
2100 if matches {
2101 let fact_hash = format!(
2102 "{:?}",
2103 extract_scalar_key(src_batch, &all_src_indices, src_row)
2104 )
2105 .into_bytes();
2106 inputs.push(ProofTerm {
2107 source_rule: binding.rule_name.clone(),
2108 base_fact_id: fact_hash,
2109 });
2110 }
2111 }
2112 }
2113 }
2114
2115 inputs
2116}
2117
2118fn collect_is_ref_inputs_for_body_row(
2140 rule: &FixpointRulePlan,
2141 delta_batch: &RecordBatch,
2142 row_idx: usize,
2143 registry: &Arc<DerivedScanRegistry>,
2144) -> Vec<ProofTerm> {
2145 let mut combined: Vec<ProofTerm> = Vec::new();
2146 for clause_index in 0..rule.clauses.len() {
2147 let part = collect_is_ref_inputs(rule, clause_index, delta_batch, row_idx, registry);
2148 combined.extend(part);
2149 }
2150 combined
2151}
2152
/// One row of a KEY group that `detect_shared_lineage` flagged as sharing
/// base facts with another row of the same group.
#[expect(
    dead_code,
    reason = "Fields accessed via SharedLineageInfo in detect_shared_lineage"
)]
pub(crate) struct SharedGroupRow {
    /// Row hash of the pre-fold fact, as produced by `fact_hash_key`.
    pub fact_hash: Vec<u8>,
    /// Base-fact ids reachable from this fact, as produced by `compute_lineage`.
    pub lineage: HashSet<Vec<u8>>,
}
2179
2180pub(crate) struct SharedLineageInfo {
2182 pub shared_groups: HashMap<Vec<ScalarKey>, Vec<SharedGroupRow>>,
2184}
2185
2186pub(crate) fn fact_hash_key(batch: &RecordBatch, all_indices: &[usize], row_idx: usize) -> Vec<u8> {
2188 format!("{:?}", extract_scalar_key(batch, all_indices, row_idx)).into_bytes()
2189}
2190
2191fn detect_shared_lineage(
2194 rule: &FixpointRulePlan,
2195 pre_fold_facts: &[RecordBatch],
2196 tracker: &Arc<ProvenanceStore>,
2197 warnings_slot: &Arc<StdRwLock<Vec<RuntimeWarning>>>,
2198 semiring_kind: SemiringKind,
2199) -> Option<SharedLineageInfo> {
2200 use uni_locy::{RuntimeWarning, RuntimeWarningCode};
2201
2202 let has_prob_fold = rule
2207 .fold_bindings
2208 .iter()
2209 .any(|fb| fb.aggregate.is_probability_aggregate());
2210 if !has_prob_fold {
2211 return None;
2212 }
2213
2214 let key_indices = &rule.key_column_indices;
2216 let all_indices: Vec<usize> = (0..rule.yield_schema.fields().len()).collect();
2217
2218 let mut groups: HashMap<Vec<ScalarKey>, Vec<Vec<u8>>> = HashMap::new();
2219 for batch in pre_fold_facts {
2220 for row_idx in 0..batch.num_rows() {
2221 let key = extract_scalar_key(batch, key_indices, row_idx);
2222 let fact_hash = fact_hash_key(batch, &all_indices, row_idx);
2223 groups.entry(key).or_default().push(fact_hash);
2224 }
2225 }
2226
2227 let mut shared_groups: HashMap<Vec<ScalarKey>, Vec<SharedGroupRow>> = HashMap::new();
2228 let mut any_shared = false;
2229
2230 for (key, fact_hashes) in &groups {
2232 if fact_hashes.len() < 2 {
2233 continue;
2234 }
2235
2236 let mut has_inputs = false;
2238 let mut per_row_bases: Vec<HashSet<Vec<u8>>> = Vec::new();
2239 for fh in fact_hashes {
2240 let bases = compute_lineage(fh, tracker, &mut HashSet::new());
2241 if let Some(entry) = tracker.lookup(fh)
2242 && !entry.support.is_empty()
2243 {
2244 has_inputs = true;
2245 }
2246 per_row_bases.push(bases);
2247 }
2248
2249 let shared_found = if has_inputs {
2250 let mut found = false;
2252 'outer: for i in 0..per_row_bases.len() {
2253 for j in (i + 1)..per_row_bases.len() {
2254 if !per_row_bases[i].is_disjoint(&per_row_bases[j]) {
2255 found = true;
2256 break 'outer;
2257 }
2258 }
2259 }
2260 found
2261 } else {
2262 fact_hashes.iter().any(|fh| {
2265 tracker.lookup(fh).is_some_and(|entry| {
2266 rule.clauses
2267 .get(entry.clause_index)
2268 .is_some_and(|clause| clause.is_ref_bindings.iter().any(|b| !b.negated))
2269 })
2270 })
2271 };
2272
2273 if shared_found {
2274 any_shared = true;
2275 let rows: Vec<SharedGroupRow> = fact_hashes
2277 .iter()
2278 .zip(per_row_bases)
2279 .map(|(fh, bases)| SharedGroupRow {
2280 fact_hash: fh.clone(),
2281 lineage: bases,
2282 })
2283 .collect();
2284 shared_groups.insert(key.clone(), rows);
2285 }
2286 }
2287
2288 {
2294 let mut input_to_groups: HashMap<Vec<u8>, HashSet<Vec<ScalarKey>>> = HashMap::new();
2295 for (key, fact_hashes) in &groups {
2296 for fh in fact_hashes {
2297 if let Some(entry) = tracker.lookup(fh) {
2298 for input in &entry.support {
2299 input_to_groups
2300 .entry(input.base_fact_id.clone())
2301 .or_default()
2302 .insert(key.clone());
2303 }
2304 }
2305 }
2306 }
2307 let has_cross_group = input_to_groups.values().any(|g| g.len() > 1);
2308 if has_cross_group && let Ok(mut warnings) = warnings_slot.write() {
2309 let already_warned = warnings.iter().any(|w| {
2310 w.code == RuntimeWarningCode::CrossGroupCorrelationNotExact
2311 && w.rule_name == rule.name
2312 });
2313 if !already_warned {
2314 let example =
2318 input_to_groups
2319 .iter()
2320 .find(|(_, g)| g.len() > 1)
2321 .map(|(input, groups)| {
2322 let short = input
2323 .iter()
2324 .take(8)
2325 .map(|b| format!("{:02x}", b))
2326 .collect::<String>();
2327 let mut group_strs: Vec<String> =
2328 groups.iter().map(|k| format!("{:?}", k)).collect();
2329 group_strs.sort();
2330 format!(
2331 "input {} shared by groups [{}]",
2332 short,
2333 group_strs.join(", ")
2334 )
2335 });
2336 let shared_variable_count =
2342 input_to_groups.values().filter(|g| g.len() > 1).count();
2343 warnings.push(RuntimeWarning {
2344 code: RuntimeWarningCode::CrossGroupCorrelationNotExact,
2345 message: format!(
2346 "Rule '{}': {} IS-ref base fact(s) are shared across different \
2347 KEY groups. BDD corrects per-group probabilities but cannot \
2348 account for cross-group correlations.",
2349 rule.name, shared_variable_count
2350 ),
2351 rule_name: rule.name.clone(),
2352 variable_count: Some(shared_variable_count),
2353 key_group: example,
2354 });
2355 }
2356 }
2357 }
2358
2359 if any_shared {
2360 let suppress_under_topk = matches!(semiring_kind, SemiringKind::TopKProofs { .. });
2370 if !suppress_under_topk && let Ok(mut warnings) = warnings_slot.write() {
2371 let already_warned = warnings.iter().any(|w| {
2372 w.code == RuntimeWarningCode::SharedProbabilisticDependency
2373 && w.rule_name == rule.name
2374 });
2375 if !already_warned {
2376 warnings.push(RuntimeWarning {
2377 code: RuntimeWarningCode::SharedProbabilisticDependency,
2378 message: format!(
2379 "Rule '{}' aggregates with MNOR/MPROD but some proof paths \
2380 share intermediate facts, violating the independence assumption. \
2381 Results may overestimate probability.",
2382 rule.name
2383 ),
2384 rule_name: rule.name.clone(),
2385 variable_count: None,
2386 key_group: None,
2387 });
2388 }
2389 }
2390 Some(SharedLineageInfo { shared_groups })
2391 } else {
2392 None
2393 }
2394}
2395
2396#[allow(
2404 clippy::too_many_arguments,
2405 reason = "context bundle would be over-engineering for one call site"
2406)]
2407pub(crate) async fn record_and_detect_lineage_nonrecursive(
2408 rule: &FixpointRulePlan,
2409 tagged_clause_facts: &[(usize, Vec<RecordBatch>)],
2410 tracker: &Arc<ProvenanceStore>,
2411 warnings_slot: &Arc<StdRwLock<Vec<RuntimeWarning>>>,
2412 registry: &Arc<DerivedScanRegistry>,
2413 top_k_proofs: usize,
2414 classifiers: ClassifierRefs<'_>,
2415 semiring_kind: SemiringKind,
2416) -> Option<SharedLineageInfo> {
2417 let classifier_registry = classifiers.registry;
2418 let classifier_cache = classifiers.cache;
2419 let all_indices: Vec<usize> = (0..rule.yield_schema.fields().len()).collect();
2420
2421 let base_probs = if top_k_proofs > 0 {
2423 tracker.base_fact_probs()
2424 } else {
2425 HashMap::new()
2426 };
2427
2428 let mut topk_acc = TopKProofAccumulator::new();
2429
2430 for (clause_index, batches) in tagged_clause_facts {
2432 for batch in batches {
2433 for row_idx in 0..batch.num_rows() {
2434 let row_hash = fact_hash_key(batch, &all_indices, row_idx);
2435 let fact_row = batch_row_to_value_map(batch, row_idx);
2436
2437 let support = collect_is_ref_inputs(rule, *clause_index, batch, row_idx, registry);
2438
2439 let proof_probability = if top_k_proofs > 0 {
2440 compute_proof_probability(&support, &base_probs)
2441 } else {
2442 None
2443 };
2444
2445 let entry = ProvenanceAnnotation {
2446 rule_name: rule.name.clone(),
2447 clause_index: *clause_index,
2448 support,
2449 along_values: {
2450 let along_names: Vec<String> = rule
2451 .clauses
2452 .get(*clause_index)
2453 .map(|c| c.along_bindings.clone())
2454 .unwrap_or_default();
2455 along_names
2456 .iter()
2457 .filter_map(|name| {
2458 fact_row.get(name).map(|v| (name.clone(), v.clone()))
2459 })
2460 .collect()
2461 },
2462 iteration: 0,
2463 fact_row: fact_row.clone(),
2464 proof_probability,
2465 neural_calls: collect_neural_calls_for_row(
2466 rule,
2467 *clause_index,
2468 &fact_row,
2469 classifier_registry,
2470 classifier_cache,
2471 classifiers.provenance_store,
2472 )
2473 .await,
2474 };
2475 if top_k_proofs > 0 {
2476 topk_acc.accumulate(&entry, &row_hash);
2477 tracker.record_top_k(row_hash, entry, top_k_proofs);
2478 } else {
2479 tracker.record(row_hash, entry);
2480 }
2481 }
2482 }
2483 }
2484
2485 topk_acc.emit_warning_if_any(rule, top_k_proofs, warnings_slot);
2486
2487 let all_facts: Vec<RecordBatch> = tagged_clause_facts
2489 .iter()
2490 .flat_map(|(_, batches)| batches.iter().cloned())
2491 .collect();
2492 detect_shared_lineage(rule, &all_facts, tracker, warnings_slot, semiring_kind)
2493}
2494
2495pub(crate) fn apply_exact_wmc(
2503 pre_fold_facts: Vec<RecordBatch>,
2504 rule: &FixpointRulePlan,
2505 shared_info: &SharedLineageInfo,
2506 tracker: &Arc<ProvenanceStore>,
2507 max_bdd_variables: usize,
2508 warnings_slot: &Arc<StdRwLock<Vec<RuntimeWarning>>>,
2509 approximate_slot: &Arc<StdRwLock<HashMap<String, Vec<String>>>>,
2510) -> DFResult<Vec<RecordBatch>> {
2511 use crate::query::df_graph::locy_bdd::{SemiringOp, weighted_model_count};
2512 use uni_locy::{RuntimeWarning, RuntimeWarningCode};
2513
2514 let prob_fold = rule
2518 .fold_bindings
2519 .iter()
2520 .find(|fb| fb.aggregate.is_probability_aggregate());
2521 let prob_fold = match prob_fold {
2522 Some(f) => f,
2523 None => return Ok(pre_fold_facts),
2524 };
2525 let semiring_op = if prob_fold.aggregate.is_noisy_or() {
2526 SemiringOp::Disjunction
2527 } else {
2528 SemiringOp::Conjunction
2529 };
2530 let prob_col_idx = prob_fold.input_col_index;
2531 let prob_col_name = rule.yield_schema.field(prob_col_idx).name().clone();
2532
2533 let key_indices = &rule.key_column_indices;
2534 let all_indices: Vec<usize> = (0..rule.yield_schema.fields().len()).collect();
2535
2536 let shared_keys: HashSet<Vec<ScalarKey>> = shared_info.shared_groups.keys().cloned().collect();
2538
2539 struct GroupAccum {
2542 base_facts: Vec<HashSet<Vec<u8>>>,
2543 base_probs: HashMap<Vec<u8>, f64>,
2544 representative: (usize, usize),
2546 row_locations: Vec<(usize, usize)>,
2547 }
2548
2549 let mut group_accums: HashMap<Vec<ScalarKey>, GroupAccum> = HashMap::new();
2550 let mut non_shared_rows: Vec<(usize, usize)> = Vec::new(); for (batch_idx, batch) in pre_fold_facts.iter().enumerate() {
2553 for row_idx in 0..batch.num_rows() {
2554 let key = extract_scalar_key(batch, key_indices, row_idx);
2555 if shared_keys.contains(&key) {
2556 let fact_hash = fact_hash_key(batch, &all_indices, row_idx);
2557 let bases = compute_lineage(&fact_hash, tracker, &mut HashSet::new());
2558
2559 let accum = group_accums.entry(key).or_insert_with(|| GroupAccum {
2560 base_facts: Vec::new(),
2561 base_probs: HashMap::new(),
2562 representative: (batch_idx, row_idx),
2563 row_locations: Vec::new(),
2564 });
2565
2566 for bf in &bases {
2568 if !accum.base_probs.contains_key(bf)
2569 && let Some(entry) = tracker.lookup(bf)
2570 && let Some(val) = entry.fact_row.get(&prob_col_name)
2571 && let Some(p) = value_to_f64(val)
2572 {
2573 accum.base_probs.insert(bf.clone(), p);
2574 }
2575 }
2576
2577 accum.base_facts.push(bases);
2578 accum.row_locations.push((batch_idx, row_idx));
2579 } else {
2580 non_shared_rows.push((batch_idx, row_idx));
2581 }
2582 }
2583 }
2584
2585 let mut keep_rows: HashSet<(usize, usize)> = HashSet::new();
2588 let mut overrides: HashMap<(usize, usize), f64> = HashMap::new();
2590
2591 for &loc in &non_shared_rows {
2593 keep_rows.insert(loc);
2594 }
2595
2596 for (key, accum) in &group_accums {
2597 let bdd_result = weighted_model_count(
2598 &accum.base_facts,
2599 &accum.base_probs,
2600 semiring_op,
2601 max_bdd_variables,
2602 );
2603
2604 if bdd_result.approximated {
2605 if let Ok(mut warnings) = warnings_slot.write() {
2607 let key_desc = format!("{key:?}");
2608 let already_warned = warnings.iter().any(|w| {
2609 w.code == RuntimeWarningCode::BddLimitExceeded
2610 && w.rule_name == rule.name
2611 && w.key_group.as_deref() == Some(&key_desc)
2612 });
2613 if !already_warned {
2614 warnings.push(RuntimeWarning {
2615 code: RuntimeWarningCode::BddLimitExceeded,
2616 message: format!(
2617 "Rule '{}': BDD variable limit exceeded ({} > {}). \
2618 Falling back to independence-mode result.",
2619 rule.name, bdd_result.variable_count, max_bdd_variables
2620 ),
2621 rule_name: rule.name.clone(),
2622 variable_count: Some(bdd_result.variable_count),
2623 key_group: Some(key_desc),
2624 });
2625 }
2626 }
2627 if let Ok(mut approx) = approximate_slot.write() {
2628 let key_desc = format!("{key:?}");
2629 approx.entry(rule.name.clone()).or_default().push(key_desc);
2630 }
2631 for &loc in &accum.row_locations {
2633 keep_rows.insert(loc);
2634 }
2635 } else {
2636 keep_rows.insert(accum.representative);
2638 overrides.insert(accum.representative, bdd_result.probability);
2639 }
2640 }
2641
2642 let mut result_batches = Vec::new();
2644 for (batch_idx, batch) in pre_fold_facts.iter().enumerate() {
2645 let kept_indices: Vec<usize> = (0..batch.num_rows())
2646 .filter(|&row_idx| keep_rows.contains(&(batch_idx, row_idx)))
2647 .collect();
2648
2649 if kept_indices.is_empty() {
2650 continue;
2651 }
2652
2653 let indices = arrow::array::UInt32Array::from(
2654 kept_indices.iter().map(|&i| i as u32).collect::<Vec<_>>(),
2655 );
2656 let mut columns: Vec<arrow::array::ArrayRef> = batch
2657 .columns()
2658 .iter()
2659 .map(|col| arrow::compute::take(col, &indices, None))
2660 .collect::<Result<Vec<_>, _>>()
2661 .map_err(arrow_err)?;
2662
2663 let override_map: Vec<Option<f64>> = kept_indices
2665 .iter()
2666 .map(|&row_idx| overrides.get(&(batch_idx, row_idx)).copied())
2667 .collect();
2668
2669 if override_map.iter().any(|o| o.is_some()) && prob_col_idx < columns.len() {
2670 let existing_prob = columns[prob_col_idx]
2672 .as_any()
2673 .downcast_ref::<arrow::array::Float64Array>();
2674 let new_values: Vec<f64> = override_map
2675 .iter()
2676 .enumerate()
2677 .map(|(i, ov)| match ov {
2678 Some(p) => *p,
2679 None => existing_prob.map(|arr| arr.value(i)).unwrap_or(0.0),
2680 })
2681 .collect();
2682 columns[prob_col_idx] = Arc::new(arrow::array::Float64Array::from(new_values));
2683 }
2684
2685 let result_batch = RecordBatch::try_new(batch.schema(), columns).map_err(arrow_err)?;
2686 result_batches.push(result_batch);
2687 }
2688
2689 Ok(result_batches)
2690}
2691
2692fn value_to_f64(val: &uni_common::Value) -> Option<f64> {
2694 match val {
2695 uni_common::Value::Float(f) => Some(*f),
2696 uni_common::Value::Int(i) => Some(*i as f64),
2697 _ => None,
2698 }
2699}
2700
2701fn compute_lineage(
2708 fact_hash: &[u8],
2709 tracker: &Arc<ProvenanceStore>,
2710 visited: &mut HashSet<Vec<u8>>,
2711) -> HashSet<Vec<u8>> {
2712 if !visited.insert(fact_hash.to_vec()) {
2713 return HashSet::new(); }
2715
2716 match tracker.lookup(fact_hash) {
2717 Some(entry) if !entry.support.is_empty() => {
2718 let mut bases = HashSet::new();
2719 for input in &entry.support {
2720 let child_bases = compute_lineage(&input.base_fact_id, tracker, visited);
2721 bases.extend(child_bases);
2722 }
2723 bases
2724 }
2725 _ => {
2726 let mut set = HashSet::new();
2728 set.insert(fact_hash.to_vec());
2729 set
2730 }
2731 }
2732}
2733
2734fn find_clause_for_row(
2739 delta_batch: &RecordBatch,
2740 row_idx: usize,
2741 all_indices: &[usize],
2742 clause_candidates: &[Vec<RecordBatch>],
2743) -> usize {
2744 let target_key = extract_scalar_key(delta_batch, all_indices, row_idx);
2745 for (clause_idx, batches) in clause_candidates.iter().enumerate() {
2746 for batch in batches {
2747 if batch.num_columns() != all_indices.len() {
2748 continue;
2749 }
2750 for r in 0..batch.num_rows() {
2751 if extract_scalar_key(batch, all_indices, r) == target_key {
2752 return clause_idx;
2753 }
2754 }
2755 }
2756 }
2757 0
2758}
2759
2760fn batch_row_to_value_map(
2762 batch: &RecordBatch,
2763 row_idx: usize,
2764) -> std::collections::HashMap<String, Value> {
2765 use uni_store::storage::arrow_convert::arrow_to_value;
2766
2767 let schema = batch.schema();
2768 schema
2769 .fields()
2770 .iter()
2771 .enumerate()
2772 .map(|(col_idx, field)| {
2773 let col = batch.column(col_idx);
2774 let val = arrow_to_value(col.as_ref(), row_idx, None);
2775 (field.name().clone(), val)
2776 })
2777 .collect()
2778}
2779
2780pub fn apply_anti_join(
2785 batches: Vec<RecordBatch>,
2786 neg_facts: &[RecordBatch],
2787 left_col: &str,
2788 right_col: &str,
2789) -> datafusion::error::Result<Vec<RecordBatch>> {
2790 use arrow::compute::filter_record_batch;
2791 use arrow_array::{Array as _, BooleanArray, UInt64Array};
2792
2793 let mut banned: std::collections::HashSet<u64> = std::collections::HashSet::new();
2795 for batch in neg_facts {
2796 let Ok(idx) = batch.schema().index_of(right_col) else {
2797 continue;
2798 };
2799 let arr = batch.column(idx);
2800 let Some(vids) = arr.as_any().downcast_ref::<UInt64Array>() else {
2801 continue;
2802 };
2803 for i in 0..vids.len() {
2804 if !vids.is_null(i) {
2805 banned.insert(vids.value(i));
2806 }
2807 }
2808 }
2809
2810 if banned.is_empty() {
2811 return Ok(batches);
2812 }
2813
2814 let mut result = Vec::new();
2816 for batch in batches {
2817 let Ok(idx) = batch.schema().index_of(left_col) else {
2818 result.push(batch);
2819 continue;
2820 };
2821 let arr = batch.column(idx);
2822 let Some(vids) = arr.as_any().downcast_ref::<UInt64Array>() else {
2823 result.push(batch);
2824 continue;
2825 };
2826 let keep: Vec<bool> = (0..vids.len())
2827 .map(|i| vids.is_null(i) || !banned.contains(&vids.value(i)))
2828 .collect();
2829 let keep_arr = BooleanArray::from(keep);
2830 let filtered = filter_record_batch(&batch, &keep_arr).map_err(arrow_err)?;
2831 if filtered.num_rows() > 0 {
2832 result.push(filtered);
2833 }
2834 }
2835 Ok(result)
2836}
2837
2838#[allow(clippy::too_many_arguments)]
2859pub(crate) async fn apply_model_invocations(
2860 batches: Vec<RecordBatch>,
2861 invocations: &[uni_locy::ModelInvocation],
2862 registry: &Arc<ClassifierRegistry>,
2863 cache: Option<&Arc<uni_locy::ModelInvocationCache>>,
2864 provenance_store: Option<&Arc<uni_locy::NeuralProvenanceStore>>,
2865 path_context_handles: &HashMap<
2866 String,
2867 crate::query::df_graph::locy_model_invoke::PathContextHandle,
2868 >,
2869 xervo_runtime: &crate::query::df_graph::locy_model_invoke::XervoRuntimeHandle,
2870 graph_algo: &crate::query::df_graph::locy_model_invoke::GraphAlgoHandle,
2871) -> DFResult<Vec<RecordBatch>> {
2872 use uni_locy::ClassifyInput;
2873 if batches.is_empty() || invocations.is_empty() {
2874 return Ok(batches);
2875 }
2876 let semantic_match_embeddings =
2880 pre_embed_semantic_match_queries(invocations, xervo_runtime).await?;
2881 let graph_feature_maps = precompute_graph_feature_maps(invocations, graph_algo).await?;
2886 let neighbor_feature_maps =
2887 precompute_neighbor_feature_maps(invocations, &batches, graph_algo).await?;
2888 let mut out_batches = Vec::with_capacity(batches.len());
2889 for batch in batches {
2890 let mut current = batch;
2891 for invocation in invocations {
2892 let classifier = registry.get(&invocation.model_name).ok_or_else(|| {
2893 datafusion::error::DataFusionError::Execution(format!(
2894 "neural classifier '{}' not registered; \
2895 add it to LocyConfig::classifier_registry",
2896 invocation.model_name
2897 ))
2898 })?;
2899
2900 let resolvers = build_feature_resolvers(
2912 ¤t,
2913 invocation,
2914 path_context_handles,
2915 &semantic_match_embeddings,
2916 &graph_feature_maps,
2917 &neighbor_feature_maps,
2918 )?;
2919
2920 let n_rows = current.num_rows();
2922 let mut inputs: Vec<ClassifyInput> = Vec::with_capacity(n_rows);
2923 let mut input_hashes: Vec<u64> = Vec::with_capacity(n_rows);
2924 for row_idx in 0..n_rows {
2925 let mut features = std::collections::HashMap::new();
2926 for resolver in &resolvers {
2927 let value = resolver.eval_row(¤t, row_idx)?;
2928 features.insert(resolver.binding_name.clone(), value);
2929 }
2930 let input = ClassifyInput { features };
2931 input_hashes.push(input.stable_hash());
2932 inputs.push(input);
2933 }
2934
2935 let mut probs: Vec<f64> = vec![0.0; n_rows];
2939 let mut miss_inputs: Vec<ClassifyInput> = Vec::new();
2940 let mut miss_row_indices: Vec<usize> = Vec::new();
2941 if let Some(c) = cache {
2942 for (row_idx, h) in input_hashes.iter().enumerate() {
2943 match c.get(&invocation.model_name, *h) {
2944 Some(v) => probs[row_idx] = v,
2945 None => {
2946 miss_row_indices.push(row_idx);
2947 miss_inputs.push(inputs[row_idx].clone());
2948 }
2949 }
2950 }
2951 } else {
2952 miss_row_indices = (0..n_rows).collect();
2953 miss_inputs = inputs.clone();
2954 }
2955
2956 let calibrator = classifier.get_calibrator();
2965 let (miss_raws, miss_calibrated) = if miss_inputs.is_empty() {
2966 (Vec::new(), Vec::new())
2967 } else if calibrator.is_some() {
2968 let pairs = classifier
2969 .raw_and_calibrated(&miss_inputs)
2970 .await
2971 .map_err(|e| datafusion::error::DataFusionError::Execution(e.to_string()))?;
2972 if pairs.len() != miss_inputs.len() {
2973 return Err(datafusion::error::DataFusionError::Execution(format!(
2974 "classifier '{}' raw_and_calibrated returned {} outputs for {} inputs",
2975 invocation.model_name,
2976 pairs.len(),
2977 miss_inputs.len()
2978 )));
2979 }
2980 let raws: Vec<f64> = pairs.iter().map(|(r, _)| *r).collect();
2981 let cals: Vec<f64> = pairs.iter().map(|(r, c)| c.unwrap_or(*r)).collect();
2982 (raws, cals)
2983 } else {
2984 let r = classifier
2985 .classify(&miss_inputs)
2986 .await
2987 .map_err(|e| datafusion::error::DataFusionError::Execution(e.to_string()))?;
2988 if r.len() != miss_inputs.len() {
2989 return Err(datafusion::error::DataFusionError::Execution(format!(
2990 "classifier '{}' returned {} outputs for {} inputs",
2991 invocation.model_name,
2992 r.len(),
2993 miss_inputs.len()
2994 )));
2995 }
2996 (r.clone(), r)
2998 };
2999 let mut row_raw: Vec<Option<f64>> = vec![None; n_rows];
3009 for (i, &row_idx) in miss_row_indices.iter().enumerate() {
3010 probs[row_idx] = miss_calibrated[i];
3011 row_raw[row_idx] = Some(miss_raws[i]);
3012 if let Some(c) = cache {
3013 c.insert(
3014 &invocation.model_name,
3015 input_hashes[row_idx],
3016 miss_calibrated[i],
3017 );
3018 }
3019 }
3020
3021 if let Some(store) = provenance_store {
3029 for row_idx in 0..n_rows {
3030 let calibrated_value = probs[row_idx];
3031 let (raw_value, calibrated) = match (row_raw[row_idx], &calibrator) {
3032 (Some(raw), Some(_)) => (raw, Some(calibrated_value)),
3033 (Some(raw), None) => (raw, None),
3034 (None, _) => (
3039 calibrated_value,
3040 calibrator.as_ref().map(|_| calibrated_value),
3041 ),
3042 };
3043 let band = calibrator
3044 .as_ref()
3045 .and_then(|c| c.confidence_band(calibrated_value));
3046 store.record(
3047 &invocation.model_name,
3048 input_hashes[row_idx],
3049 uni_locy::NeuralProvenanceRecord {
3050 raw_probability: raw_value,
3051 calibrated_probability: calibrated,
3052 confidence_band: band,
3053 feature_inputs: inputs[row_idx].features.clone(),
3061 },
3062 );
3063 }
3064 }
3065
3066 let out_col: Arc<dyn arrow_array::Array> =
3071 Arc::new(arrow_array::Float64Array::from(probs));
3072 let schema = current.schema();
3073 let target_idx = schema.index_of(&invocation.output_column).ok();
3074 let mut columns: Vec<Arc<dyn arrow_array::Array>> = current.columns().to_vec();
3075 let mut fields: Vec<Arc<arrow_schema::Field>> =
3076 schema.fields().iter().cloned().collect();
3077 match target_idx {
3078 Some(idx) => {
3079 columns[idx] = out_col;
3080 fields[idx] = Arc::new(arrow_schema::Field::new(
3083 &invocation.output_column,
3084 arrow_schema::DataType::Float64,
3085 true,
3086 ));
3087 }
3088 None => {
3089 columns.push(out_col);
3090 fields.push(Arc::new(arrow_schema::Field::new(
3091 &invocation.output_column,
3092 arrow_schema::DataType::Float64,
3093 true,
3094 )));
3095 }
3096 }
3097 let new_schema = Arc::new(arrow_schema::Schema::new(fields));
3098 current = RecordBatch::try_new(new_schema, columns).map_err(arrow_err)?;
3099 }
3100 out_batches.push(current);
3101 }
3102 Ok(out_batches)
3103}
3104
3105struct FeatureResolver {
3113 binding_name: String,
3114 kind: FeatureResolverKind,
3115}
3116
3117enum FeatureResolverKind {
3118 Direct(usize),
3119 SimilarTo {
3120 left: FeatureValueSrc,
3121 right: FeatureValueSrc,
3122 },
3123 PathContext {
3128 subject_col: usize,
3129 vid_to_value: Arc<HashMap<u64, uni_locy::FeatureValue>>,
3130 },
3131 GraphAlgoScore {
3136 subject_col: usize,
3137 vid_to_score: Arc<HashMap<u64, f64>>,
3138 },
3139 NeighborAggregate {
3145 subject_col: usize,
3146 op: NeighborAgg,
3147 vid_to_values: Arc<HashMap<u64, Vec<f64>>>,
3148 },
3149}
3150
/// Reduction applied to a node's list of neighbour property values.
#[derive(Debug)]
#[derive(Clone, Copy)]
enum NeighborAgg {
    /// Arithmetic mean (`avg_neighbor`).
    Avg,
    /// Maximum (`max_neighbor`).
    Max,
    /// Sum (`sum_neighbor`).
    Sum,
}
3157
/// Which edges of a node count as its neighbours for a neighbour aggregate.
///
/// Part of the key of the precomputed neighbour maps, hence `Eq + Hash`.
#[derive(Debug, Clone, Copy)]
#[derive(PartialEq, Eq, Hash)]
enum NeighborDirection {
    /// Follow edges leaving the subject node.
    Outgoing,
    /// Follow edges arriving at the subject node.
    Incoming,
    /// Follow edges in both directions.
    Both,
}
3167
3168impl NeighborDirection {
3169 fn store_directions(self) -> &'static [uni_store::storage::direction::Direction] {
3170 use uni_store::storage::direction::Direction;
3171 match self {
3172 NeighborDirection::Outgoing => &[Direction::Outgoing],
3173 NeighborDirection::Incoming => &[Direction::Incoming],
3174 NeighborDirection::Both => &[Direction::Outgoing, Direction::Incoming],
3175 }
3176 }
3177}
3178
3179impl NeighborAgg {
3180 fn from_fn_name(name: &str) -> Option<Self> {
3181 match name {
3182 "avg_neighbor" => Some(NeighborAgg::Avg),
3183 "max_neighbor" => Some(NeighborAgg::Max),
3184 "sum_neighbor" => Some(NeighborAgg::Sum),
3185 _ => None,
3186 }
3187 }
3188
3189 fn apply(self, values: &[f64]) -> Option<f64> {
3190 if values.is_empty() {
3191 return None;
3192 }
3193 match self {
3194 NeighborAgg::Avg => Some(values.iter().sum::<f64>() / values.len() as f64),
3195 NeighborAgg::Max => values.iter().copied().reduce(f64::max),
3196 NeighborAgg::Sum => Some(values.iter().sum()),
3197 }
3198 }
3199}
3200
3201enum FeatureValueSrc {
3204 Col(usize),
3205 Const(uni_common::Value),
3206}
3207
3208impl FeatureValueSrc {
3209 fn resolve(&self, batch: &RecordBatch, row_idx: usize) -> uni_common::Value {
3210 match self {
3211 FeatureValueSrc::Col(idx) => extract_common_value(batch.column(*idx).as_ref(), row_idx),
3212 FeatureValueSrc::Const(v) => v.clone(),
3213 }
3214 }
3215}
3216
3217impl FeatureResolver {
3218 fn eval_row(&self, batch: &RecordBatch, row_idx: usize) -> DFResult<uni_locy::FeatureValue> {
3219 match &self.kind {
3220 FeatureResolverKind::Direct(idx) => {
3221 Ok(extract_feature_value(batch.column(*idx).as_ref(), row_idx))
3222 }
3223 FeatureResolverKind::SimilarTo { left, right } => {
3224 let lv = left.resolve(batch, row_idx);
3225 let rv = right.resolve(batch, row_idx);
3226 match crate::query::similar_to::eval_similar_to_pure(&lv, &rv) {
3227 Ok(uni_common::Value::Float(f)) => Ok(uni_locy::FeatureValue::Float(f)),
3228 Ok(_) => Ok(uni_locy::FeatureValue::Null),
3229 Err(e) => Err(datafusion::error::DataFusionError::Execution(format!(
3230 "similar_to UDF failed: {e}"
3231 ))),
3232 }
3233 }
3234 FeatureResolverKind::PathContext {
3235 subject_col,
3236 vid_to_value,
3237 } => {
3238 let col = batch.column(*subject_col);
3239 if col.is_null(row_idx) {
3240 return Ok(uni_locy::FeatureValue::Null);
3241 }
3242 if let Some(arr) = col.as_any().downcast_ref::<arrow_array::UInt64Array>() {
3243 let vid = arr.value(row_idx);
3244 Ok(vid_to_value
3245 .get(&vid)
3246 .cloned()
3247 .unwrap_or(uni_locy::FeatureValue::Null))
3248 } else if let Some(arr) = col.as_any().downcast_ref::<arrow_array::Int64Array>() {
3249 let vid = arr.value(row_idx) as u64;
3250 Ok(vid_to_value
3251 .get(&vid)
3252 .cloned()
3253 .unwrap_or(uni_locy::FeatureValue::Null))
3254 } else {
3255 Ok(uni_locy::FeatureValue::Null)
3256 }
3257 }
3258 FeatureResolverKind::GraphAlgoScore {
3259 subject_col,
3260 vid_to_score,
3261 } => {
3262 let col = batch.column(*subject_col);
3263 if col.is_null(row_idx) {
3264 return Ok(uni_locy::FeatureValue::Null);
3265 }
3266 let vid_opt: Option<u64> = if let Some(arr) =
3267 col.as_any().downcast_ref::<arrow_array::UInt64Array>()
3268 {
3269 Some(arr.value(row_idx))
3270 } else if let Some(arr) = col.as_any().downcast_ref::<arrow_array::Int64Array>() {
3271 Some(arr.value(row_idx) as u64)
3272 } else {
3273 match extract_common_value(col.as_ref(), row_idx) {
3279 uni_common::Value::Node(n) => Some(n.vid.as_u64()),
3280 uni_common::Value::Int(i) => Some(i as u64),
3281 _ => None,
3282 }
3283 };
3284 Ok(vid_opt
3285 .and_then(|v| vid_to_score.get(&v).copied())
3286 .map(uni_locy::FeatureValue::Float)
3287 .unwrap_or(uni_locy::FeatureValue::Null))
3288 }
3289 FeatureResolverKind::NeighborAggregate {
3290 subject_col,
3291 op,
3292 vid_to_values,
3293 } => {
3294 let vid_opt = extract_vid_from_column(batch.column(*subject_col).as_ref(), row_idx);
3295 Ok(vid_opt
3296 .and_then(|v| vid_to_values.get(&v))
3297 .and_then(|values| op.apply(values))
3298 .map(uni_locy::FeatureValue::Float)
3299 .unwrap_or(uni_locy::FeatureValue::Null))
3300 }
3301 }
3302 }
3303}
3304
3305fn extract_vid_from_column(col: &dyn arrow_array::Array, row_idx: usize) -> Option<u64> {
3310 if col.is_null(row_idx) {
3311 return None;
3312 }
3313 if let Some(arr) = col.as_any().downcast_ref::<arrow_array::UInt64Array>() {
3314 return Some(arr.value(row_idx));
3315 }
3316 if let Some(arr) = col.as_any().downcast_ref::<arrow_array::Int64Array>() {
3317 return Some(arr.value(row_idx) as u64);
3318 }
3319 match extract_common_value(col, row_idx) {
3320 uni_common::Value::Node(n) => Some(n.vid.as_u64()),
3321 uni_common::Value::Int(i) => Some(i as u64),
3322 _ => None,
3323 }
3324}
3325
3326#[allow(clippy::too_many_arguments)]
3327fn build_feature_resolvers(
3328 batch: &RecordBatch,
3329 invocation: &uni_locy::ModelInvocation,
3330 path_context_handles: &HashMap<
3331 String,
3332 crate::query::df_graph::locy_model_invoke::PathContextHandle,
3333 >,
3334 semantic_match_embeddings: &HashMap<String, Vec<f32>>,
3335 graph_feature_maps: &HashMap<String, Arc<HashMap<u64, f64>>>,
3336 neighbor_feature_maps: &NeighborFeatureMaps,
3337) -> DFResult<Vec<FeatureResolver>> {
3338 use uni_cypher::ast::Expr;
3339 let schema = batch.schema();
3340 let lookup_col = |name_or_property: String| -> DFResult<usize> {
3341 schema.index_of(&name_or_property).map_err(|_| {
3342 datafusion::error::DataFusionError::Execution(format!(
3343 "feature column '{name_or_property}' not found in clause body output schema"
3344 ))
3345 })
3346 };
3347 let resolve_src = |expr: &Expr| -> DFResult<FeatureValueSrc> {
3352 match expr {
3353 Expr::Variable(name) => {
3354 let col = if schema.index_of(name).is_ok() {
3355 name.clone()
3356 } else {
3357 let vid_name = format!("{}._vid", name);
3358 if schema.index_of(&vid_name).is_ok() {
3359 vid_name
3360 } else {
3361 name.clone()
3362 }
3363 };
3364 Ok(FeatureValueSrc::Col(lookup_col(col)?))
3365 }
3366 Expr::Property(boxed, prop) if matches!(boxed.as_ref(), Expr::Variable(_)) => {
3367 let Expr::Variable(v) = boxed.as_ref() else {
3368 unreachable!()
3369 };
3370 let direct = format!("{}.{}", v, prop);
3371 let col = if schema.index_of(&direct).is_ok() {
3372 direct
3373 } else {
3374 format!("__feat_{}_{}", v, prop)
3375 };
3376 Ok(FeatureValueSrc::Col(lookup_col(col)?))
3377 }
3378 Expr::Literal(lit) => Ok(FeatureValueSrc::Const(lit.to_value())),
3379 Expr::List(items) => {
3380 let mut out = Vec::with_capacity(items.len());
3381 for it in items {
3382 out.push(match it {
3383 Expr::Literal(lit) => lit.to_value(),
3384 _ => uni_common::Value::Null,
3385 });
3386 }
3387 Ok(FeatureValueSrc::Const(uni_common::Value::List(out)))
3388 }
3389 other => Err(datafusion::error::DataFusionError::Execution(format!(
3390 "unsupported feature sub-expression: {other:?}"
3391 ))),
3392 }
3393 };
3394
3395 if let Some(pc) = &invocation.path_context {
3403 let handle = path_context_handles.get(&pc.source_rule).ok_or_else(|| {
3404 datafusion::error::DataFusionError::Execution(format!(
3405 "model '{}' path_context references rule '{}' but no DerivedScanHandle \
3406 was registered; this should never happen — the build_clause path \
3407 mints a handle for every distinct source_rule in the invocation set",
3408 invocation.model_name, pc.source_rule
3409 ))
3410 })?;
3411 let subject_col = schema
3412 .index_of(&format!("{}._vid", pc.subject_var))
3413 .or_else(|_| schema.index_of(&pc.subject_var))
3414 .map_err(|_| {
3415 datafusion::error::DataFusionError::Execution(format!(
3416 "model '{}' path_context: subject column '{}' (or '{0}._vid') not \
3417 in body batch schema",
3418 invocation.model_name, pc.subject_var
3419 ))
3420 })?;
3421 let vid_to_value =
3422 build_path_context_lookup(handle, &pc.subject_var, &pc.column, &invocation.model_name)?;
3423 return Ok(vec![FeatureResolver {
3424 binding_name: pc.column.clone(),
3425 kind: FeatureResolverKind::PathContext {
3426 subject_col,
3427 vid_to_value: Arc::new(vid_to_value),
3428 },
3429 }]);
3430 }
3431
3432 let mut out = Vec::with_capacity(invocation.feature_exprs.len());
3433 for (i, fexpr) in invocation.feature_exprs.iter().enumerate() {
3434 let binding_name = invocation.feature_names[i].clone();
3435 let kind = match fexpr {
3436 Expr::FunctionCall { name, args, .. } if name == "similar_to" => {
3437 if args.len() != 2 {
3438 return Err(datafusion::error::DataFusionError::Execution(format!(
3439 "similar_to expects 2 args, got {}",
3440 args.len()
3441 )));
3442 }
3443 FeatureResolverKind::SimilarTo {
3444 left: resolve_src(&args[0])?,
3445 right: resolve_src(&args[1])?,
3446 }
3447 }
3448 Expr::FunctionCall { name, args, .. } if name == "semantic_match" => {
3449 if args.len() != 2 {
3454 return Err(datafusion::error::DataFusionError::Execution(format!(
3455 "semantic_match expects 2 args, got {}",
3456 args.len()
3457 )));
3458 }
3459 let text = match &args[1] {
3460 Expr::Literal(uni_cypher::ast::CypherLiteral::String(s)) => s.clone(),
3461 other => {
3462 return Err(datafusion::error::DataFusionError::Execution(format!(
3463 "semantic_match: 2nd arg must be a string literal, got {other:?}"
3464 )));
3465 }
3466 };
3467 let embedded = semantic_match_embeddings.get(&text).ok_or_else(|| {
3468 datafusion::error::DataFusionError::Execution(format!(
3469 "semantic_match: query text '{text}' was not pre-embedded. \
3470 This is a bug — `apply_model_invocations` should have \
3471 embedded all unique semantic_match texts up front. Most \
3472 likely the Xervo runtime is not configured (configure \
3473 via `LocyConfig::xervo_runtime` or its equivalent)."
3474 ))
3475 })?;
3476 let right_vec: Vec<f32> = embedded.clone();
3477 FeatureResolverKind::SimilarTo {
3478 left: resolve_src(&args[0])?,
3479 right: FeatureValueSrc::Const(uni_common::Value::Vector(right_vec)),
3480 }
3481 }
3482 Expr::FunctionCall { name, args, .. }
3483 if matches!(
3484 name.as_str(),
3485 "degree_centrality"
3486 | "pagerank_score"
3487 | "closeness_centrality"
3488 | "betweenness_centrality"
3489 | "eigenvector_centrality"
3490 | "harmonic_centrality"
3491 | "katz_centrality"
3492 ) =>
3493 {
3494 if args.len() != 1 {
3495 return Err(datafusion::error::DataFusionError::Execution(format!(
3496 "{name} expects 1 arg, got {}",
3497 args.len()
3498 )));
3499 }
3500 let Expr::Variable(v) = &args[0] else {
3501 return Err(datafusion::error::DataFusionError::Execution(format!(
3502 "{name}(...) argument must be a node variable, got {:?}",
3503 args[0]
3504 )));
3505 };
3506 let subject_col = {
3507 let direct = schema.index_of(v).ok();
3508 let vid_name = format!("{}._vid", v);
3509 let vid_col = schema.index_of(&vid_name).ok();
3510 vid_col.or(direct).ok_or_else(|| {
3511 datafusion::error::DataFusionError::Execution(format!(
3512 "{name}: subject column '{v}' (or '{v}._vid') not in body batch schema"
3513 ))
3514 })?
3515 };
3516 let vid_to_score = graph_feature_maps.get(name).cloned().ok_or_else(|| {
3517 datafusion::error::DataFusionError::Execution(format!(
3518 "{name}: pre-computed score map missing. This is a bug — \
3519 `apply_model_invocations` should have called \
3520 `precompute_graph_feature_maps` for every graph-structural \
3521 feature before building resolvers. Most likely the graph \
3522 algorithm registry is not configured."
3523 ))
3524 })?;
3525 FeatureResolverKind::GraphAlgoScore {
3526 subject_col,
3527 vid_to_score,
3528 }
3529 }
3530 Expr::FunctionCall { name, args, .. }
3531 if matches!(
3532 name.as_str(),
3533 "avg_neighbor" | "max_neighbor" | "sum_neighbor"
3534 ) =>
3535 {
3536 if args.len() != 3 && args.len() != 4 {
3537 return Err(datafusion::error::DataFusionError::Execution(format!(
3538 "{name} expects 3 or 4 args, got {}",
3539 args.len()
3540 )));
3541 }
3542 let Expr::Variable(v) = &args[0] else {
3543 return Err(datafusion::error::DataFusionError::Execution(format!(
3544 "{name}(...) first argument must be a node variable, got {:?}",
3545 args[0]
3546 )));
3547 };
3548 let rel_type = match &args[1] {
3549 Expr::Literal(uni_cypher::ast::CypherLiteral::String(s)) => s.clone(),
3550 other => {
3551 return Err(datafusion::error::DataFusionError::Execution(format!(
3552 "{name}: 2nd arg must be a string literal (rel-type), got {other:?}"
3553 )));
3554 }
3555 };
3556 let prop_name = match &args[2] {
3557 Expr::Literal(uni_cypher::ast::CypherLiteral::String(s)) => s.clone(),
3558 other => {
3559 return Err(datafusion::error::DataFusionError::Execution(format!(
3560 "{name}: 3rd arg must be a string literal (property), got {other:?}"
3561 )));
3562 }
3563 };
3564 let direction_arg = match args.get(3) {
3565 None => NeighborDirection::Outgoing,
3566 Some(Expr::Literal(uni_cypher::ast::CypherLiteral::String(d))) => {
3567 match d.to_uppercase().as_str() {
3568 "OUTGOING" => NeighborDirection::Outgoing,
3569 "INCOMING" => NeighborDirection::Incoming,
3570 "BOTH" => NeighborDirection::Both,
3571 other => {
3572 return Err(datafusion::error::DataFusionError::Execution(
3573 format!(
3574 "{name}: direction must be OUTGOING|INCOMING|BOTH, got '{other}'"
3575 ),
3576 ));
3577 }
3578 }
3579 }
3580 Some(other) => {
3581 return Err(datafusion::error::DataFusionError::Execution(format!(
3582 "{name}: 4th arg must be a string literal (direction), got {other:?}"
3583 )));
3584 }
3585 };
3586 let subject_col = {
3587 let direct = schema.index_of(v).ok();
3588 let vid_name = format!("{}._vid", v);
3589 let vid_col = schema.index_of(&vid_name).ok();
3590 vid_col.or(direct).ok_or_else(|| {
3591 datafusion::error::DataFusionError::Execution(format!(
3592 "{name}: subject column '{v}' (or '{v}._vid') not in body batch schema"
3593 ))
3594 })?
3595 };
3596 let vid_to_values = neighbor_feature_maps
3597 .get(&(rel_type.clone(), prop_name.clone(), direction_arg))
3598 .cloned()
3599 .ok_or_else(|| {
3600 datafusion::error::DataFusionError::Execution(format!(
3601 "{name}: pre-computed neighbor map missing for ({rel_type}, {prop_name}, {direction_arg:?}). \
3602 This is a bug — `apply_model_invocations` should have called \
3603 `precompute_neighbor_feature_maps` for every neighbor-aggregator \
3604 feature before building resolvers."
3605 ))
3606 })?;
3607 let op = NeighborAgg::from_fn_name(name).unwrap();
3608 FeatureResolverKind::NeighborAggregate {
3609 subject_col,
3610 op,
3611 vid_to_values,
3612 }
3613 }
3614 other => match resolve_src(other)? {
3615 FeatureValueSrc::Col(idx) => FeatureResolverKind::Direct(idx),
3616 FeatureValueSrc::Const(_) => {
3617 return Err(datafusion::error::DataFusionError::Execution(format!(
3618 "model '{}' feature must reference a variable or property — got a literal",
3619 invocation.model_name
3620 )));
3621 }
3622 },
3623 };
3624 out.push(FeatureResolver { binding_name, kind });
3625 }
3626 Ok(out)
3627}
3628
3629async fn pre_embed_semantic_match_queries(
3636 invocations: &[uni_locy::ModelInvocation],
3637 xervo_runtime: &crate::query::df_graph::locy_model_invoke::XervoRuntimeHandle,
3638) -> DFResult<HashMap<String, Vec<f32>>> {
3639 use uni_cypher::ast::{CypherLiteral, Expr};
3640 let mut needed: Vec<(String, String)> = Vec::new();
3650 for inv in invocations {
3651 let alias = inv
3652 .embedder_alias
3653 .clone()
3654 .unwrap_or_else(|| "default".to_string());
3655 for fexpr in &inv.feature_exprs {
3656 if let Expr::FunctionCall { name, args, .. } = fexpr
3657 && name == "semantic_match"
3658 && args.len() == 2
3659 && let Expr::Literal(CypherLiteral::String(s)) = &args[1]
3660 {
3661 let tuple = (s.clone(), alias.clone());
3662 if !needed.contains(&tuple) {
3663 needed.push(tuple);
3664 }
3665 }
3666 }
3667 }
3668 if needed.is_empty() {
3669 return Ok(HashMap::new());
3670 }
3671 let runtime = xervo_runtime.as_ref().ok_or_else(|| {
3672 datafusion::error::DataFusionError::Execution(
3673 "semantic_match: Uni-Xervo runtime not configured. Either provide \
3674 one via `LocyConfig::xervo_runtime` (or its equivalent setup \
3675 path) or pre-compute the query embedding and pass it via \
3676 `similar_to(prop, <literal_vector>)`."
3677 .to_string(),
3678 )
3679 })?;
3680 let mut by_alias: HashMap<String, Vec<String>> = HashMap::new();
3683 for (text, alias) in &needed {
3684 by_alias
3685 .entry(alias.clone())
3686 .or_default()
3687 .push(text.clone());
3688 }
3689 let mut out: HashMap<String, Vec<f32>> = HashMap::new();
3690 for (alias, texts) in by_alias {
3691 let embedder = runtime.embedding(&alias).await.map_err(|e| {
3692 datafusion::error::DataFusionError::Execution(format!(
3693 "semantic_match: failed to obtain embedder for alias '{alias}': {e}. \
3694 Register an embedder under that alias in your Uni-Xervo runtime, or \
3695 pre-compute the query embedding and pass via similar_to."
3696 ))
3697 })?;
3698 let text_refs: Vec<&str> = texts.iter().map(|s| s.as_str()).collect();
3699 let embeddings = embedder
3700 .embed(&text_refs)
3701 .await
3702 .map_err(|e| {
3703 datafusion::error::DataFusionError::Execution(format!(
3704 "semantic_match: embedder '{alias}' call failed: {e}"
3705 ))
3706 })?
3707 .vectors;
3708 if embeddings.len() != texts.len() {
3709 return Err(datafusion::error::DataFusionError::Execution(format!(
3710 "semantic_match: embedder '{alias}' returned {} vectors for {} queries",
3711 embeddings.len(),
3712 texts.len()
3713 )));
3714 }
3715 for (text, vec) in texts.into_iter().zip(embeddings) {
3716 out.insert(text, vec);
3717 }
3718 }
3719 Ok(out)
3720}
3721
/// Builds a VID → score map for every graph-structural FEATURE function
/// (`pagerank_score`, `degree_centrality`, …) referenced by any invocation.
///
/// Each distinct function name is mapped by `procedure_for` to an
/// algorithm-registry procedure, and that procedure is executed. Its `nodeId`
/// and `score`/`centrality` yield columns become the map entries. Returns an
/// empty map, without touching the registry or storage, when no invocation
/// uses such a feature.
///
/// # Errors
/// Fails in any of these cases:
/// - the algorithm registry or storage handle is not configured;
/// - a procedure is missing from the registry or lacks the expected yield
///   columns;
/// - argument validation, projection building or execution fails.
async fn precompute_graph_feature_maps(
    invocations: &[uni_locy::ModelInvocation],
    graph_algo: &crate::query::df_graph::locy_model_invoke::GraphAlgoHandle,
) -> DFResult<HashMap<String, Arc<HashMap<u64, f64>>>> {
    use futures::StreamExt;
    use uni_algo::algo::procedures::AlgoContext;
    use uni_cypher::ast::Expr;

    /// Maps a FEATURE function name to the registry procedure that computes it.
    fn procedure_for(fn_name: &str) -> Option<&'static str> {
        match fn_name {
            "degree_centrality" => Some("uni.algo.degreeCentrality"),
            "pagerank_score" => Some("uni.algo.pageRank"),
            "closeness_centrality" => Some("uni.algo.closeness"),
            "betweenness_centrality" => Some("uni.algo.betweenness"),
            "eigenvector_centrality" => Some("uni.algo.eigenvectorCentrality"),
            "harmonic_centrality" => Some("uni.algo.harmonicCentrality"),
            "katz_centrality" => Some("uni.algo.katzCentrality"),
            _ => None,
        }
    }

    // Distinct graph-structural function names, in first-seen order.
    let mut needed: Vec<String> = Vec::new();
    for inv in invocations {
        for fexpr in &inv.feature_exprs {
            if let Expr::FunctionCall { name, .. } = fexpr
                && procedure_for(name).is_some()
                && !needed.contains(name)
            {
                needed.push(name.clone());
            }
        }
    }
    if needed.is_empty() {
        return Ok(HashMap::new());
    }

    let registry = graph_algo.registry.as_ref().ok_or_else(|| {
        datafusion::error::DataFusionError::Execution(
            "graph-structural FEATURE invoked but no `AlgorithmRegistry` is \
             configured. Configure one on `GraphExecutionContext::with_algo_registry`."
                .to_string(),
        )
    })?;
    let storage = graph_algo.storage.as_ref().ok_or_else(|| {
        datafusion::error::DataFusionError::Execution(
            "graph-structural FEATURE invoked but no storage handle was \
             threaded into the FEATURE runtime. This is a bug in df_planner."
                .to_string(),
        )
    })?;

    let mut out: HashMap<String, Arc<HashMap<u64, f64>>> = HashMap::new();
    for fn_name in needed {
        // Cannot fail: `needed` only holds names for which `procedure_for` is `Some`.
        let proc_name = procedure_for(&fn_name).unwrap();
        let procedure = registry.get(proc_name).ok_or_else(|| {
            datafusion::error::DataFusionError::Execution(format!(
                "graph-structural FEATURE '{fn_name}' resolves to procedure \
                 '{proc_name}' which is not in the algorithm registry"
            ))
        })?;
        // Two empty list arguments.
        // NOTE(review): presumably the node-label / edge-type filters, with
        // empty meaning the whole graph. Confirm against the procedure
        // signatures.
        let args: Vec<serde_json::Value> = vec![
            serde_json::Value::Array(Vec::new()),
            serde_json::Value::Array(Vec::new()),
        ];
        let algo_ctx = AlgoContext::new(
            storage.clone(),
            graph_algo.l0_manager.as_ref().map(Arc::clone),
        );
        // NOTE(review): the validated `filled_args` are only used to build the
        // projection below; `execute_with_projection` receives the raw `args`.
        // Confirm this split is intended.
        let filled_args = procedure
            .signature()
            .validate_args(args.clone())
            .map_err(|e| {
                datafusion::error::DataFusionError::Execution(format!(
                    "graph-structural FEATURE '{fn_name}': argument validation failed: {e}"
                ))
            })?;
        let projection = uni_algo::algo::procedure_template::build_projection_from_direct_args(
            procedure.as_ref(),
            &algo_ctx,
            &filled_args,
        )
        .await
        .map_err(|e| {
            datafusion::error::DataFusionError::Execution(format!(
                "graph-structural FEATURE '{fn_name}': projection build failed: {e}"
            ))
        })?;
        let mut stream = procedure.execute_with_projection(algo_ctx, args, projection);
        let mut score_map: HashMap<u64, f64> = HashMap::new();
        // Locate the VID and score columns in the procedure's yield schema.
        let sig = procedure.signature();
        let node_idx = sig
            .yields
            .iter()
            .position(|(n, _)| *n == "nodeId")
            .ok_or_else(|| {
                datafusion::error::DataFusionError::Execution(format!(
                    "procedure '{proc_name}' yield schema missing 'nodeId'"
                ))
            })?;
        // Procedures name their score column either `score` or `centrality`.
        let score_idx = sig
            .yields
            .iter()
            .position(|(n, _)| *n == "score" || *n == "centrality")
            .ok_or_else(|| {
                datafusion::error::DataFusionError::Execution(format!(
                    "procedure '{proc_name}' yield schema missing a numeric score column \
                     (expected 'score' or 'centrality')"
                ))
            })?;
        while let Some(row_res) = stream.next().await {
            let row = row_res.map_err(|e| {
                datafusion::error::DataFusionError::Execution(format!(
                    "graph-structural FEATURE '{fn_name}': procedure '{proc_name}' failed: {e}"
                ))
            })?;
            // Rows that are too short, or whose VID/score are not numeric, are skipped.
            let vid_v = row.values.get(node_idx);
            let score_v = row.values.get(score_idx);
            let (Some(vid_v), Some(score_v)) = (vid_v, score_v) else {
                continue;
            };
            // Negative integer VIDs are reinterpreted with `as u64`.
            let vid = vid_v.as_u64().or_else(|| vid_v.as_i64().map(|i| i as u64));
            let score = score_v
                .as_f64()
                .or_else(|| score_v.as_i64().map(|i| i as f64));
            if let (Some(vid), Some(score)) = (vid, score) {
                score_map.insert(vid, score);
            }
        }
        out.insert(fn_name, Arc::new(score_map));
    }
    Ok(out)
}
3899
3900type NeighborFeatureMaps =
3926 HashMap<(String, String, NeighborDirection), Arc<HashMap<u64, Vec<f64>>>>;
3927
3928async fn precompute_neighbor_feature_maps(
3929 invocations: &[uni_locy::ModelInvocation],
3930 batches: &[RecordBatch],
3931 graph_algo: &crate::query::df_graph::locy_model_invoke::GraphAlgoHandle,
3932) -> DFResult<NeighborFeatureMaps> {
3933 use uni_cypher::ast::{CypherLiteral, Expr};
3934
3935 let parse_direction = |arg: Option<&Expr>| -> Option<NeighborDirection> {
3940 match arg {
3941 None => Some(NeighborDirection::Outgoing),
3942 Some(Expr::Literal(CypherLiteral::String(d))) => match d.to_uppercase().as_str() {
3943 "OUTGOING" => Some(NeighborDirection::Outgoing),
3944 "INCOMING" => Some(NeighborDirection::Incoming),
3945 "BOTH" => Some(NeighborDirection::Both),
3946 _ => None,
3947 },
3948 _ => None,
3949 }
3950 };
3951 let mut needed: Vec<(String, String, String, NeighborDirection)> = Vec::new();
3952 for inv in invocations {
3953 for fexpr in &inv.feature_exprs {
3954 if let Expr::FunctionCall { name, args, .. } = fexpr
3955 && NeighborAgg::from_fn_name(name).is_some()
3956 && (args.len() == 3 || args.len() == 4)
3957 && let Expr::Variable(v) = &args[0]
3958 && let Expr::Literal(CypherLiteral::String(rel)) = &args[1]
3959 && let Expr::Literal(CypherLiteral::String(prop)) = &args[2]
3960 && let Some(direction) = parse_direction(args.get(3))
3961 {
3962 let tuple = (v.clone(), rel.clone(), prop.clone(), direction);
3963 if !needed.contains(&tuple) {
3964 needed.push(tuple);
3965 }
3966 }
3967 }
3968 }
3969 if needed.is_empty() {
3970 return Ok(HashMap::new());
3971 }
3972
3973 let storage = graph_algo.storage.as_ref().ok_or_else(|| {
3974 datafusion::error::DataFusionError::Execution(
3975 "neighbor-aggregator FEATURE invoked but no storage handle was \
3976 threaded into the FEATURE runtime. This is a bug in df_planner."
3977 .to_string(),
3978 )
3979 })?;
3980 let property_manager = graph_algo.property_manager.as_ref().ok_or_else(|| {
3981 datafusion::error::DataFusionError::Execution(
3982 "neighbor-aggregator FEATURE invoked but no PropertyManager was \
3983 threaded into the FEATURE runtime. This is a bug in df_planner."
3984 .to_string(),
3985 )
3986 })?;
3987 let query_ctx = graph_algo.l0_buffers.as_ref().map(|bufs| {
3993 uni_store::runtime::context::QueryContext::new_with_pending(
3994 bufs.current.clone(),
3995 bufs.transaction.clone(),
3996 bufs.pending_flush.clone(),
3997 )
3998 });
3999
4000 let mut by_key: HashMap<(String, String, NeighborDirection), Vec<String>> = HashMap::new();
4004 for (subject_var, rel, prop, direction) in needed {
4005 by_key
4006 .entry((rel, prop, direction))
4007 .or_default()
4008 .push(subject_var);
4009 }
4010
4011 let mut out: NeighborFeatureMaps = HashMap::new();
4012 for ((rel_type, prop_name, direction), subject_vars) in by_key {
4013 let schema = storage.schema_manager().schema();
4015 let Some(edge_meta) = schema.edge_types.get(&rel_type) else {
4016 out.insert((rel_type, prop_name, direction), Arc::new(HashMap::new()));
4019 continue;
4020 };
4021 let edge_type_id = edge_meta.id;
4022
4023 let edge_ver = storage.get_edge_version_by_id(edge_type_id);
4026 for dir in direction.store_directions() {
4027 storage
4028 .warm_adjacency(edge_type_id, *dir, edge_ver)
4029 .await
4030 .map_err(|e| {
4031 datafusion::error::DataFusionError::Execution(format!(
4032 "neighbor-aggregator warm_adjacency for '{rel_type}' / {dir:?} failed: {e}"
4033 ))
4034 })?;
4035 }
4036
4037 let mut subject_vids: std::collections::HashSet<u64> = std::collections::HashSet::new();
4040 for subject_var in &subject_vars {
4041 for batch in batches {
4042 let schema = batch.schema();
4043 let col_idx = schema
4044 .index_of(&format!("{}._vid", subject_var))
4045 .ok()
4046 .or_else(|| schema.index_of(subject_var).ok());
4047 let Some(col_idx) = col_idx else { continue };
4048 let col = batch.column(col_idx);
4049 for row in 0..batch.num_rows() {
4050 if let Some(v) = extract_vid_from_column(col.as_ref(), row) {
4051 subject_vids.insert(v);
4052 }
4053 }
4054 }
4055 }
4056
4057 let mut vid_to_values: HashMap<u64, Vec<f64>> = HashMap::new();
4062 let adj = storage.adjacency_manager();
4063 for subject_vid in subject_vids {
4064 let mut neighbors: Vec<(uni_common::core::id::Vid, uni_common::core::id::Eid)> =
4065 Vec::new();
4066 for dir in direction.store_directions() {
4067 neighbors.extend(adj.get_neighbors(
4068 uni_common::core::id::Vid::from(subject_vid),
4069 edge_type_id,
4070 *dir,
4071 ));
4072 }
4073 let mut values: Vec<f64> = Vec::with_capacity(neighbors.len());
4074 for (neighbor_vid, _eid) in neighbors {
4075 let val = property_manager
4076 .get_vertex_prop_with_ctx(neighbor_vid, &prop_name, query_ctx.as_ref())
4077 .await
4078 .map_err(|e| {
4079 datafusion::error::DataFusionError::Execution(format!(
4080 "neighbor-aggregator: failed to read property \
4081 '{prop_name}' on neighbor vid {neighbor_vid:?}: {e}"
4082 ))
4083 })?;
4084 if let Some(f) = val.as_f64()
4085 && !f.is_nan()
4086 {
4087 values.push(f);
4088 }
4089 }
4090 vid_to_values.insert(subject_vid, values);
4091 }
4092 out.insert((rel_type, prop_name, direction), Arc::new(vid_to_values));
4093 }
4094 Ok(out)
4095}
4096
4097fn build_path_context_lookup(
4103 handle: &crate::query::df_graph::locy_model_invoke::PathContextHandle,
4104 _subject_var: &str,
4105 column: &str,
4106 model_name: &str,
4107) -> DFResult<HashMap<u64, uni_locy::FeatureValue>> {
4108 if handle.schema.fields().is_empty() {
4113 return Err(datafusion::error::DataFusionError::Execution(format!(
4114 "model '{model_name}' path_context: source rule has empty yield schema"
4115 )));
4116 }
4117 let subj_idx = 0_usize;
4118 let col_idx = handle.schema.index_of(column).map_err(|_| {
4119 datafusion::error::DataFusionError::Execution(format!(
4120 "model '{model_name}' path_context: column '{column}' not in \
4121 source rule's yield schema (have: {:?})",
4122 handle
4123 .schema
4124 .fields()
4125 .iter()
4126 .map(|f| f.name().clone())
4127 .collect::<Vec<_>>()
4128 ))
4129 })?;
4130 let batches = handle.data.read();
4131 let mut out: HashMap<u64, uni_locy::FeatureValue> = HashMap::new();
4132 for batch in batches.iter() {
4133 let subj_col = batch.column(subj_idx);
4134 let value_col = batch.column(col_idx);
4135 for row in 0..batch.num_rows() {
4136 if subj_col.is_null(row) {
4137 continue;
4138 }
4139 let vid = if let Some(a) = subj_col.as_any().downcast_ref::<arrow_array::UInt64Array>()
4140 {
4141 a.value(row)
4142 } else if let Some(a) = subj_col.as_any().downcast_ref::<arrow_array::Int64Array>() {
4143 a.value(row) as u64
4144 } else {
4145 continue;
4146 };
4147 let v = extract_feature_value(value_col.as_ref(), row);
4148 out.insert(vid, v);
4151 }
4152 }
4153 Ok(out)
4154}
4155
4156fn extract_common_value(col: &dyn arrow_array::Array, row_idx: usize) -> uni_common::Value {
4161 use arrow_array::{BooleanArray, Float64Array, Int64Array, LargeStringArray, StringArray};
4162 if col.is_null(row_idx) {
4163 return uni_common::Value::Null;
4164 }
4165 if let Some(a) = col.as_any().downcast_ref::<Float64Array>() {
4166 return uni_common::Value::Float(a.value(row_idx));
4167 }
4168 if let Some(a) = col.as_any().downcast_ref::<Int64Array>() {
4169 return uni_common::Value::Int(a.value(row_idx));
4170 }
4171 if let Some(a) = col.as_any().downcast_ref::<BooleanArray>() {
4172 return uni_common::Value::Bool(a.value(row_idx));
4173 }
4174 if let Some(a) = col.as_any().downcast_ref::<StringArray>() {
4175 return uni_common::Value::String(a.value(row_idx).to_string());
4176 }
4177 if let Some(a) = col.as_any().downcast_ref::<LargeStringArray>() {
4178 return uni_common::Value::String(a.value(row_idx).to_string());
4179 }
4180 if let Some(b) = col.as_any().downcast_ref::<arrow_array::LargeBinaryArray>() {
4181 let bytes = b.value(row_idx);
4182 if bytes.is_empty() {
4183 return uni_common::Value::Null;
4184 }
4185 return uni_common::cypher_value_codec::decode(bytes).unwrap_or(uni_common::Value::Null);
4186 }
4187 uni_common::Value::Null
4188}
4189
4190fn extract_feature_value(col: &dyn arrow_array::Array, row_idx: usize) -> uni_locy::FeatureValue {
4191 use arrow_array::{BooleanArray, Float64Array, Int64Array, LargeStringArray, StringArray};
4192 if col.is_null(row_idx) {
4193 return uni_locy::FeatureValue::Null;
4194 }
4195 if let Some(a) = col.as_any().downcast_ref::<Float64Array>() {
4196 return uni_locy::FeatureValue::Float(a.value(row_idx));
4197 }
4198 if let Some(a) = col.as_any().downcast_ref::<Int64Array>() {
4199 return uni_locy::FeatureValue::Int(a.value(row_idx));
4200 }
4201 if let Some(a) = col.as_any().downcast_ref::<BooleanArray>() {
4202 return uni_locy::FeatureValue::Bool(a.value(row_idx));
4203 }
4204 if let Some(a) = col.as_any().downcast_ref::<StringArray>() {
4205 return uni_locy::FeatureValue::String(a.value(row_idx).to_string());
4206 }
4207 if let Some(a) = col.as_any().downcast_ref::<LargeStringArray>() {
4208 return uni_locy::FeatureValue::String(a.value(row_idx).to_string());
4209 }
4210 if let Some(b) = col.as_any().downcast_ref::<arrow_array::LargeBinaryArray>() {
4214 let bytes = b.value(row_idx);
4215 if bytes.is_empty() {
4216 return uni_locy::FeatureValue::Null;
4217 }
4218 let v = uni_common::cypher_value_codec::decode(bytes).unwrap_or(uni_common::Value::Null);
4219 return match v {
4220 uni_common::Value::Float(f) => uni_locy::FeatureValue::Float(f),
4221 uni_common::Value::Int(i) => uni_locy::FeatureValue::Int(i),
4222 uni_common::Value::Bool(b) => uni_locy::FeatureValue::Bool(b),
4223 uni_common::Value::String(s) => uni_locy::FeatureValue::String(s),
4224 uni_common::Value::Null => uni_locy::FeatureValue::Null,
4225 _ => uni_locy::FeatureValue::Null,
4226 };
4227 }
4228 uni_locy::FeatureValue::Null
4229}
4230
4231pub fn apply_prob_complement(
4238 batches: Vec<RecordBatch>,
4239 neg_facts: &[RecordBatch],
4240 left_col: &str,
4241 right_col: &str,
4242 prob_col: &str,
4243 complement_col_name: &str,
4244) -> datafusion::error::Result<Vec<RecordBatch>> {
4245 use arrow_array::{Array as _, Float64Array, UInt64Array};
4246
4247 let mut prob_map: std::collections::HashMap<u64, f64> = std::collections::HashMap::new();
4249 for batch in neg_facts {
4250 let Ok(vid_idx) = batch.schema().index_of(right_col) else {
4251 continue;
4252 };
4253 let Ok(prob_idx) = batch.schema().index_of(prob_col) else {
4254 continue;
4255 };
4256 let Some(vids) = batch.column(vid_idx).as_any().downcast_ref::<UInt64Array>() else {
4257 continue;
4258 };
4259 let prob_arr = batch.column(prob_idx);
4260 let probs = prob_arr.as_any().downcast_ref::<Float64Array>();
4261 for i in 0..vids.len() {
4262 if !vids.is_null(i) {
4263 let p = probs
4264 .and_then(|arr| {
4265 if arr.is_null(i) {
4266 None
4267 } else {
4268 Some(arr.value(i))
4269 }
4270 })
4271 .unwrap_or(0.0);
4272 prob_map
4275 .entry(vids.value(i))
4276 .and_modify(|existing| {
4277 *existing = 1.0 - (1.0 - *existing) * (1.0 - p);
4278 })
4279 .or_insert(p);
4280 }
4281 }
4282 }
4283
4284 let mut result = Vec::new();
4286 for batch in batches {
4287 let Ok(idx) = batch.schema().index_of(left_col) else {
4288 result.push(batch);
4289 continue;
4290 };
4291 let Some(vids) = batch.column(idx).as_any().downcast_ref::<UInt64Array>() else {
4292 result.push(batch);
4293 continue;
4294 };
4295
4296 let complements: Vec<f64> = (0..vids.len())
4298 .map(|i| {
4299 if vids.is_null(i) {
4300 1.0
4301 } else {
4302 let p = prob_map.get(&vids.value(i)).copied().unwrap_or(0.0);
4303 1.0 - p
4304 }
4305 })
4306 .collect();
4307
4308 let complement_arr = Float64Array::from(complements);
4309
4310 let mut columns: Vec<arrow_array::ArrayRef> = batch.columns().to_vec();
4312 columns.push(std::sync::Arc::new(complement_arr));
4313
4314 let mut fields: Vec<std::sync::Arc<arrow_schema::Field>> =
4315 batch.schema().fields().iter().cloned().collect();
4316 fields.push(std::sync::Arc::new(arrow_schema::Field::new(
4317 complement_col_name,
4318 arrow_schema::DataType::Float64,
4319 true,
4320 )));
4321
4322 let new_schema = std::sync::Arc::new(arrow_schema::Schema::new(fields));
4323 let new_batch = RecordBatch::try_new(new_schema, columns).map_err(arrow_err)?;
4324 result.push(new_batch);
4325 }
4326 Ok(result)
4327}
4328
4329pub fn apply_prob_complement_composite(
4336 batches: Vec<RecordBatch>,
4337 neg_facts: &[RecordBatch],
4338 join_cols: &[(String, String)],
4339 prob_col: &str,
4340 complement_col_name: &str,
4341) -> datafusion::error::Result<Vec<RecordBatch>> {
4342 use arrow_array::{Array as _, Float64Array, UInt64Array};
4343
4344 let mut prob_map: HashMap<Vec<u64>, f64> = HashMap::new();
4346 for batch in neg_facts {
4347 let right_indices: Vec<usize> = join_cols
4348 .iter()
4349 .filter_map(|(_, rc)| batch.schema().index_of(rc).ok())
4350 .collect();
4351 if right_indices.len() != join_cols.len() {
4352 continue;
4353 }
4354 let Ok(prob_idx) = batch.schema().index_of(prob_col) else {
4355 continue;
4356 };
4357 let prob_arr = batch.column(prob_idx);
4358 let probs = prob_arr.as_any().downcast_ref::<Float64Array>();
4359 for row in 0..batch.num_rows() {
4360 let mut key = Vec::with_capacity(right_indices.len());
4361 let mut valid = true;
4362 for &ci in &right_indices {
4363 let col = batch.column(ci);
4364 if let Some(vids) = col.as_any().downcast_ref::<UInt64Array>() {
4365 if vids.is_null(row) {
4366 valid = false;
4367 break;
4368 }
4369 key.push(vids.value(row));
4370 } else {
4371 valid = false;
4372 break;
4373 }
4374 }
4375 if !valid {
4376 continue;
4377 }
4378 let p = probs
4379 .and_then(|arr| {
4380 if arr.is_null(row) {
4381 None
4382 } else {
4383 Some(arr.value(row))
4384 }
4385 })
4386 .unwrap_or(0.0);
4387 prob_map
4389 .entry(key)
4390 .and_modify(|existing| {
4391 *existing = 1.0 - (1.0 - *existing) * (1.0 - p);
4392 })
4393 .or_insert(p);
4394 }
4395 }
4396
4397 let mut result = Vec::new();
4399 for batch in batches {
4400 let left_indices: Vec<usize> = join_cols
4401 .iter()
4402 .filter_map(|(lc, _)| batch.schema().index_of(lc).ok())
4403 .collect();
4404 if left_indices.len() != join_cols.len() {
4405 result.push(batch);
4406 continue;
4407 }
4408 let all_u64 = left_indices.iter().all(|&ci| {
4409 batch
4410 .column(ci)
4411 .as_any()
4412 .downcast_ref::<UInt64Array>()
4413 .is_some()
4414 });
4415 if !all_u64 {
4416 result.push(batch);
4417 continue;
4418 }
4419
4420 let complements: Vec<f64> = (0..batch.num_rows())
4421 .map(|row| {
4422 let mut key = Vec::with_capacity(left_indices.len());
4423 for &ci in &left_indices {
4424 let vids = batch
4425 .column(ci)
4426 .as_any()
4427 .downcast_ref::<UInt64Array>()
4428 .unwrap();
4429 if vids.is_null(row) {
4430 return 1.0;
4431 }
4432 key.push(vids.value(row));
4433 }
4434 let p = prob_map.get(&key).copied().unwrap_or(0.0);
4435 1.0 - p
4436 })
4437 .collect();
4438
4439 let complement_arr = Float64Array::from(complements);
4440 let mut columns: Vec<arrow_array::ArrayRef> = batch.columns().to_vec();
4441 columns.push(Arc::new(complement_arr));
4442
4443 let mut fields: Vec<Arc<arrow_schema::Field>> =
4444 batch.schema().fields().iter().cloned().collect();
4445 fields.push(Arc::new(arrow_schema::Field::new(
4446 complement_col_name,
4447 arrow_schema::DataType::Float64,
4448 true,
4449 )));
4450
4451 let new_schema = Arc::new(arrow_schema::Schema::new(fields));
4452 let new_batch = RecordBatch::try_new(new_schema, columns).map_err(arrow_err)?;
4453 result.push(new_batch);
4454 }
4455 Ok(result)
4456}
4457
4458pub fn apply_anti_join_composite(
4464 batches: Vec<RecordBatch>,
4465 neg_facts: &[RecordBatch],
4466 join_cols: &[(String, String)],
4467) -> datafusion::error::Result<Vec<RecordBatch>> {
4468 use arrow::compute::filter_record_batch;
4469 use arrow_array::{Array as _, BooleanArray, UInt64Array};
4470
4471 let mut banned: HashSet<Vec<u64>> = HashSet::new();
4473 for batch in neg_facts {
4474 let right_indices: Vec<usize> = join_cols
4475 .iter()
4476 .filter_map(|(_, rc)| batch.schema().index_of(rc).ok())
4477 .collect();
4478 if right_indices.len() != join_cols.len() {
4479 continue;
4480 }
4481 for row in 0..batch.num_rows() {
4482 let mut key = Vec::with_capacity(right_indices.len());
4483 let mut valid = true;
4484 for &ci in &right_indices {
4485 let col = batch.column(ci);
4486 if let Some(vids) = col.as_any().downcast_ref::<UInt64Array>() {
4487 if vids.is_null(row) {
4488 valid = false;
4489 break;
4490 }
4491 key.push(vids.value(row));
4492 } else {
4493 valid = false;
4494 break;
4495 }
4496 }
4497 if valid {
4498 banned.insert(key);
4499 }
4500 }
4501 }
4502
4503 if banned.is_empty() {
4504 return Ok(batches);
4505 }
4506
4507 let mut result = Vec::new();
4509 for batch in batches {
4510 let left_indices: Vec<usize> = join_cols
4511 .iter()
4512 .filter_map(|(lc, _)| batch.schema().index_of(lc).ok())
4513 .collect();
4514 if left_indices.len() != join_cols.len() {
4515 result.push(batch);
4516 continue;
4517 }
4518 let all_u64 = left_indices.iter().all(|&ci| {
4519 batch
4520 .column(ci)
4521 .as_any()
4522 .downcast_ref::<UInt64Array>()
4523 .is_some()
4524 });
4525 if !all_u64 {
4526 result.push(batch);
4527 continue;
4528 }
4529
4530 let keep: Vec<bool> = (0..batch.num_rows())
4531 .map(|row| {
4532 let mut key = Vec::with_capacity(left_indices.len());
4533 for &ci in &left_indices {
4534 let vids = batch
4535 .column(ci)
4536 .as_any()
4537 .downcast_ref::<UInt64Array>()
4538 .unwrap();
4539 if vids.is_null(row) {
4540 return true; }
4542 key.push(vids.value(row));
4543 }
4544 !banned.contains(&key)
4545 })
4546 .collect();
4547 let keep_arr = BooleanArray::from(keep);
4548 let filtered = filter_record_batch(&batch, &keep_arr).map_err(arrow_err)?;
4549 if filtered.num_rows() > 0 {
4550 result.push(filtered);
4551 }
4552 }
4553 Ok(result)
4554}
4555
4556pub fn multiply_prob_factors(
4567 batches: Vec<RecordBatch>,
4568 prob_col: Option<&str>,
4569 complement_cols: &[String],
4570) -> datafusion::error::Result<Vec<RecordBatch>> {
4571 use arrow_array::{Array as _, Float64Array};
4572
4573 let mut result = Vec::with_capacity(batches.len());
4574
4575 for batch in batches {
4576 if batch.num_rows() == 0 {
4577 let keep: Vec<usize> = batch
4579 .schema()
4580 .fields()
4581 .iter()
4582 .enumerate()
4583 .filter(|(_, f)| !complement_cols.contains(f.name()))
4584 .map(|(i, _)| i)
4585 .collect();
4586 let fields: Vec<_> = keep
4587 .iter()
4588 .map(|&i| batch.schema().field(i).clone())
4589 .collect();
4590 let cols: Vec<_> = keep.iter().map(|&i| batch.column(i).clone()).collect();
4591 let schema = std::sync::Arc::new(arrow_schema::Schema::new(fields));
4592 result.push(
4593 RecordBatch::try_new(schema, cols).map_err(|e| {
4594 datafusion::error::DataFusionError::ArrowError(Box::new(e), None)
4595 })?,
4596 );
4597 continue;
4598 }
4599
4600 let num_rows = batch.num_rows();
4601
4602 let mut combined = vec![1.0f64; num_rows];
4604 for col_name in complement_cols {
4605 if let Ok(idx) = batch.schema().index_of(col_name) {
4606 let arr = batch
4607 .column(idx)
4608 .as_any()
4609 .downcast_ref::<Float64Array>()
4610 .ok_or_else(|| {
4611 datafusion::error::DataFusionError::Internal(format!(
4612 "Expected Float64 for complement column {col_name}"
4613 ))
4614 })?;
4615 for (i, val) in combined.iter_mut().enumerate().take(num_rows) {
4616 if !arr.is_null(i) {
4617 *val *= arr.value(i);
4618 }
4619 }
4620 }
4621 }
4622
4623 let final_prob: Vec<f64> = if let Some(prob_name) = prob_col {
4625 if let Ok(idx) = batch.schema().index_of(prob_name) {
4626 let arr = batch
4627 .column(idx)
4628 .as_any()
4629 .downcast_ref::<Float64Array>()
4630 .ok_or_else(|| {
4631 datafusion::error::DataFusionError::Internal(format!(
4632 "Expected Float64 for PROB column {prob_name}"
4633 ))
4634 })?;
4635 (0..num_rows)
4636 .map(|i| {
4637 if arr.is_null(i) {
4638 combined[i]
4639 } else {
4640 arr.value(i) * combined[i]
4641 }
4642 })
4643 .collect()
4644 } else {
4645 combined
4646 }
4647 } else {
4648 combined
4649 };
4650
4651 let new_prob_array: arrow_array::ArrayRef =
4652 std::sync::Arc::new(Float64Array::from(final_prob));
4653
4654 let mut fields = Vec::new();
4656 let mut columns = Vec::new();
4657
4658 for (idx, field) in batch.schema().fields().iter().enumerate() {
4659 if complement_cols.contains(field.name()) {
4660 continue;
4661 }
4662 if prob_col.is_some_and(|p| field.name() == p) {
4663 fields.push(field.clone());
4664 columns.push(new_prob_array.clone());
4665 } else {
4666 fields.push(field.clone());
4667 columns.push(batch.column(idx).clone());
4668 }
4669 }
4670
4671 let schema = std::sync::Arc::new(arrow_schema::Schema::new(fields));
4672 result.push(RecordBatch::try_new(schema, columns).map_err(arrow_err)?);
4673 }
4674
4675 Ok(result)
4676}
4677
4678fn update_derived_scan_handles(
4683 registry: &DerivedScanRegistry,
4684 states: &[FixpointState],
4685 current_rule_idx: usize,
4686 rules: &[FixpointRulePlan],
4687) {
4688 let current_rule_name = &rules[current_rule_idx].name;
4689
4690 for entry in ®istry.entries {
4691 let source_state_idx = rules.iter().position(|r| r.name == entry.rule_name);
4693 let Some(source_idx) = source_state_idx else {
4694 continue;
4695 };
4696
4697 let is_self = entry.rule_name == *current_rule_name;
4698 let data = if is_self && !rules[current_rule_idx].non_linear {
4699 states[source_idx].all_delta().to_vec()
4701 } else {
4702 states[source_idx].all_facts().to_vec()
4705 };
4706
4707 let data = if data.is_empty() || data.iter().all(|b| b.num_rows() == 0) {
4709 vec![RecordBatch::new_empty(Arc::clone(&entry.schema))]
4710 } else {
4711 data
4712 };
4713
4714 let mut guard = entry.data.write();
4715 *guard = data;
4716 }
4717}
4718
4719pub struct DerivedScanExec {
4729 data: Arc<RwLock<Vec<RecordBatch>>>,
4730 schema: SchemaRef,
4731 properties: Arc<PlanProperties>,
4732}
4733
4734impl DerivedScanExec {
4735 pub fn new(data: Arc<RwLock<Vec<RecordBatch>>>, schema: SchemaRef) -> Self {
4736 let properties = compute_plan_properties(Arc::clone(&schema));
4737 Self {
4738 data,
4739 schema,
4740 properties,
4741 }
4742 }
4743}
4744
4745impl fmt::Debug for DerivedScanExec {
4746 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
4747 f.debug_struct("DerivedScanExec")
4748 .field("schema", &self.schema)
4749 .finish()
4750 }
4751}
4752
4753impl DisplayAs for DerivedScanExec {
4754 fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter<'_>) -> fmt::Result {
4755 write!(f, "DerivedScanExec")
4756 }
4757}
4758
4759impl ExecutionPlan for DerivedScanExec {
4760 fn name(&self) -> &str {
4761 "DerivedScanExec"
4762 }
4763 fn as_any(&self) -> &dyn Any {
4764 self
4765 }
4766 fn schema(&self) -> SchemaRef {
4767 Arc::clone(&self.schema)
4768 }
4769 fn properties(&self) -> &Arc<PlanProperties> {
4770 &self.properties
4771 }
4772 fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
4773 vec![]
4774 }
4775 fn with_new_children(
4776 self: Arc<Self>,
4777 _children: Vec<Arc<dyn ExecutionPlan>>,
4778 ) -> DFResult<Arc<dyn ExecutionPlan>> {
4779 Ok(self)
4780 }
4781 fn execute(
4782 &self,
4783 _partition: usize,
4784 _context: Arc<TaskContext>,
4785 ) -> DFResult<SendableRecordBatchStream> {
4786 let batches = {
4787 let guard = self.data.read();
4788 if guard.is_empty() {
4789 vec![RecordBatch::new_empty(Arc::clone(&self.schema))]
4790 } else {
4791 guard
4797 .iter()
4798 .map(|b| {
4799 RecordBatch::try_new(Arc::clone(&self.schema), b.columns().to_vec())
4800 .map_err(|e| {
4801 datafusion::error::DataFusionError::ArrowError(Box::new(e), None)
4802 })
4803 })
4804 .collect::<DFResult<Vec<_>>>()?
4805 }
4806 };
4807 Ok(Box::pin(MemoryStream::try_new(
4808 batches,
4809 Arc::clone(&self.schema),
4810 None,
4811 )?))
4812 }
4813}
4814
4815struct InMemoryExec {
4824 batches: Vec<RecordBatch>,
4825 schema: SchemaRef,
4826 properties: Arc<PlanProperties>,
4827}
4828
4829impl InMemoryExec {
4830 fn new(batches: Vec<RecordBatch>, schema: SchemaRef) -> Self {
4831 let properties = compute_plan_properties(Arc::clone(&schema));
4832 Self {
4833 batches,
4834 schema,
4835 properties,
4836 }
4837 }
4838}
4839
4840impl fmt::Debug for InMemoryExec {
4841 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
4842 f.debug_struct("InMemoryExec")
4843 .field("num_batches", &self.batches.len())
4844 .field("schema", &self.schema)
4845 .finish()
4846 }
4847}
4848
4849impl DisplayAs for InMemoryExec {
4850 fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter<'_>) -> fmt::Result {
4851 write!(f, "InMemoryExec: batches={}", self.batches.len())
4852 }
4853}
4854
4855impl ExecutionPlan for InMemoryExec {
4856 fn name(&self) -> &str {
4857 "InMemoryExec"
4858 }
4859 fn as_any(&self) -> &dyn Any {
4860 self
4861 }
4862 fn schema(&self) -> SchemaRef {
4863 Arc::clone(&self.schema)
4864 }
4865 fn properties(&self) -> &Arc<PlanProperties> {
4866 &self.properties
4867 }
4868 fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
4869 vec![]
4870 }
4871 fn with_new_children(
4872 self: Arc<Self>,
4873 _children: Vec<Arc<dyn ExecutionPlan>>,
4874 ) -> DFResult<Arc<dyn ExecutionPlan>> {
4875 Ok(self)
4876 }
4877 fn execute(
4878 &self,
4879 _partition: usize,
4880 _context: Arc<TaskContext>,
4881 ) -> DFResult<SendableRecordBatchStream> {
4882 Ok(Box::pin(MemoryStream::try_new(
4883 self.batches.clone(),
4884 Arc::clone(&self.schema),
4885 None,
4886 )?))
4887 }
4888}
4889
4890fn apply_having_filter(
4900 batches: Vec<RecordBatch>,
4901 having_exprs: &[Expr],
4902 schema: &SchemaRef,
4903 task_ctx: &Arc<TaskContext>,
4904) -> DFResult<Vec<RecordBatch>> {
4905 use arrow::compute::{and, filter_record_batch};
4906 use arrow_array::BooleanArray;
4907 use datafusion::common::DFSchema;
4908 use datafusion::logical_expr::LogicalPlanBuilder;
4909 use datafusion::logical_expr::execution_props::ExecutionProps;
4910 use datafusion::optimizer::AnalyzerRule;
4911 use datafusion::optimizer::analyzer::type_coercion::TypeCoercion;
4912 use datafusion::physical_expr::create_physical_expr;
4913
4914 if batches.is_empty() {
4915 return Ok(batches);
4916 }
4917
4918 let df_schema = DFSchema::try_from(schema.as_ref().clone()).map_err(|e| {
4920 datafusion::common::DataFusionError::Internal(format!("HAVING schema conversion: {e}"))
4921 })?;
4922
4923 let config = (**task_ctx.session_config().options()).clone();
4928 let props = ExecutionProps::new();
4929
4930 let physical_exprs: Vec<Arc<dyn datafusion::physical_expr::PhysicalExpr>> = having_exprs
4936 .iter()
4937 .map(|expr| {
4938 let df_expr = crate::query::df_expr::cypher_expr_to_df(expr, None).map_err(|e| {
4939 datafusion::common::DataFusionError::Internal(format!(
4940 "HAVING expression conversion: {e}"
4941 ))
4942 })?;
4943
4944 let empty = datafusion::logical_expr::LogicalPlan::EmptyRelation(
4948 datafusion::logical_expr::EmptyRelation {
4949 produce_one_row: false,
4950 schema: Arc::new(df_schema.clone()),
4951 },
4952 );
4953 let filter_plan = LogicalPlanBuilder::from(empty)
4954 .filter(df_expr.clone())?
4955 .build()?;
4956 let coerced_expr = match TypeCoercion::new().analyze(filter_plan, &config) {
4957 Ok(datafusion::logical_expr::LogicalPlan::Filter(f)) => f.predicate,
4958 _ => df_expr,
4959 };
4960
4961 create_physical_expr(&coerced_expr, &df_schema, &props)
4962 })
4963 .collect::<DFResult<Vec<_>>>()?;
4964
4965 let mut result = Vec::new();
4966 for batch in batches {
4967 let mut mask: Option<BooleanArray> = None;
4969 for phys_expr in &physical_exprs {
4970 let value = phys_expr.evaluate(&batch)?;
4971 let arr = value.into_array(batch.num_rows())?;
4972 let bool_arr = arr.as_any().downcast_ref::<BooleanArray>().ok_or_else(|| {
4973 datafusion::common::DataFusionError::Internal(
4974 "HAVING condition must evaluate to boolean".into(),
4975 )
4976 })?;
4977 mask = Some(match mask {
4978 None => bool_arr.clone(),
4979 Some(prev) => and(&prev, bool_arr).map_err(arrow_err)?,
4980 });
4981 }
4982 if let Some(ref m) = mask {
4983 let filtered = filter_record_batch(&batch, m).map_err(arrow_err)?;
4984 if filtered.num_rows() > 0 {
4985 result.push(filtered);
4986 }
4987 } else {
4988 result.push(batch);
4989 }
4990 }
4991 Ok(result)
4992}
4993
4994#[allow(
4996 clippy::too_many_arguments,
4997 reason = "context bundle would be over-engineering for one call site"
4998)]
4999pub(crate) async fn apply_post_fixpoint_chain(
5000 facts: Vec<RecordBatch>,
5001 rule: &FixpointRulePlan,
5002 task_ctx: &Arc<TaskContext>,
5003 strict_probability_domain: bool,
5004 probability_epsilon: f64,
5005 semiring_kind: SemiringKind,
5006 provenance_tracker: Option<Arc<ProvenanceStore>>,
5007 top_k_proofs_k: usize,
5008 registry: Option<Arc<DerivedScanRegistry>>,
5009) -> DFResult<Vec<RecordBatch>> {
5010 if !rule.has_fold && !rule.has_best_by && !rule.has_priority && rule.having.is_empty() {
5011 return Ok(facts);
5012 }
5013
5014 let schema = facts
5019 .iter()
5020 .find(|b| b.num_rows() > 0)
5021 .map(|b| b.schema())
5022 .unwrap_or_else(|| Arc::clone(&rule.yield_schema));
5023
5024 let topk_k: Option<usize> = match semiring_kind {
5038 SemiringKind::TopKProofs { k } if k > 0 => Some(k as usize),
5039 _ => None,
5040 };
5041 let body_support_map: Option<Arc<HashMap<Vec<u8>, Vec<ProofTerm>>>> = if topk_k.is_some()
5042 && !rule.has_priority
5043 && let Some(registry) = registry.as_ref()
5044 {
5045 let mut map: HashMap<Vec<u8>, Vec<ProofTerm>> = HashMap::new();
5046 for batch in &facts {
5047 let all_indices: Vec<usize> = (0..batch.num_columns()).collect();
5048 for row_idx in 0..batch.num_rows() {
5049 let support = collect_is_ref_inputs_for_body_row(rule, batch, row_idx, registry);
5050 if support.is_empty() {
5051 continue;
5052 }
5053 let hash = fact_hash_key(batch, &all_indices, row_idx);
5054 map.insert(hash, support);
5055 }
5056 }
5057 if map.is_empty() {
5058 None
5059 } else {
5060 Some(Arc::new(map))
5061 }
5062 } else {
5063 None
5064 };
5065
5066 let input: Arc<dyn ExecutionPlan> = Arc::new(InMemoryExec::new(facts, schema.clone()));
5067
5068 let key_column_indices: Vec<usize> = rule
5073 .key_column_indices
5074 .iter()
5075 .filter_map(|&i| {
5076 let name = rule.yield_schema.field(i).name();
5077 schema.index_of(name).ok()
5078 })
5079 .collect();
5080
5081 let current: Arc<dyn ExecutionPlan> = if rule.has_priority {
5085 let priority_schema = input.schema();
5086 let priority_idx = priority_schema.index_of("__priority").map_err(|_| {
5087 datafusion::common::DataFusionError::Internal(
5088 "PRIORITY rule missing __priority column".to_string(),
5089 )
5090 })?;
5091 Arc::new(PriorityExec::new(
5092 input,
5093 key_column_indices.clone(),
5094 priority_idx,
5095 ))
5096 } else {
5097 input
5098 };
5099
5100 let current: Arc<dyn ExecutionPlan> = if rule.has_fold && !rule.fold_bindings.is_empty() {
5102 Arc::new(FoldExec::new_with_topk(
5103 current,
5104 key_column_indices.clone(),
5105 rule.fold_bindings.clone(),
5106 strict_probability_domain,
5107 probability_epsilon,
5108 semiring_kind,
5109 provenance_tracker.clone(),
5110 topk_k.unwrap_or(top_k_proofs_k),
5111 body_support_map.clone(),
5112 ))
5113 } else {
5114 current
5115 };
5116
5117 let current: Arc<dyn ExecutionPlan> = if !rule.having.is_empty() {
5119 let batches = collect_all_partitions(¤t, Arc::clone(task_ctx)).await?;
5120 let filtered = apply_having_filter(batches, &rule.having, ¤t.schema(), task_ctx)?;
5121 if filtered.is_empty() {
5122 return Ok(filtered);
5123 }
5124 Arc::new(InMemoryExec::new(filtered, Arc::clone(¤t.schema())))
5125 } else {
5126 current
5127 };
5128
5129 let current: Arc<dyn ExecutionPlan> = if rule.has_best_by && !rule.best_by_criteria.is_empty() {
5131 Arc::new(BestByExec::new(
5132 current,
5133 key_column_indices.clone(),
5134 rule.best_by_criteria.clone(),
5135 rule.deterministic,
5136 ))
5137 } else {
5138 current
5139 };
5140
5141 collect_all_partitions(¤t, Arc::clone(task_ctx)).await
5142}
5143
/// Physical operator that evaluates a set of Locy fixpoint rules.
///
/// NOTE(review): the evaluation loop lives in this type's `ExecutionPlan` impl,
/// which is outside this excerpt. The field notes below describe only what the
/// declarations and nearby code show; items marked "presumably" need
/// confirming against `execute`.
pub struct FixpointExec {
    /// Rule plans handled by this operator (`Debug` reports only their count).
    rules: Vec<FixpointRulePlan>,
    /// Iteration cap; presumably bounds the fixpoint loop.
    max_iterations: usize,
    /// Time budget; presumably paired with `timeout_flag`.
    timeout: Duration,
    graph_ctx: Arc<GraphExecutionContext>,
    session_ctx: Arc<RwLock<datafusion::prelude::SessionContext>>,
    storage: Arc<StorageManager>,
    schema_info: Arc<UniSchema>,
    params: HashMap<String, Value>,
    /// Registry of derived-scan handles; presumably refreshed between rule
    /// evaluations via `update_derived_scan_handles`.
    derived_scan_registry: Arc<DerivedScanRegistry>,
    output_schema: SchemaRef,
    properties: Arc<PlanProperties>,
    metrics: ExecutionPlanMetricsSet,
    /// Presumably a memory cap on derived facts, in bytes; confirm.
    max_derived_bytes: usize,
    derivation_tracker: Option<Arc<ProvenanceStore>>,
    /// Shared output slot, keyed by a `String` (presumably the rule name).
    iteration_counts: Arc<StdRwLock<HashMap<String, usize>>>,
    strict_probability_domain: bool,
    probability_epsilon: f64,
    exact_probability: bool,
    max_bdd_variables: usize,
    /// Shared slot collecting runtime warnings for the caller.
    warnings_slot: Arc<StdRwLock<Vec<RuntimeWarning>>>,
    approximate_slot: Arc<StdRwLock<HashMap<String, Vec<String>>>>,
    top_k_proofs: usize,
    /// Shared atomic flag; presumably records a timeout condition for the caller.
    timeout_flag: Arc<std::sync::atomic::AtomicU8>,
    semiring_kind: SemiringKind,
    classifier_registry: Arc<ClassifierRegistry>,
    classifier_cache: Option<Arc<ModelInvocationCache>>,
    #[allow(
        dead_code,
        reason = "boundary plumbing; read by EXPLAIN via LocyModelInvokeExec"
    )]
    classifier_provenance_store: Option<Arc<uni_locy::NeuralProvenanceStore>>,
    profile_collector: Option<Arc<LocyProfileCollector>>,
}
5210
5211impl fmt::Debug for FixpointExec {
5212 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
5213 f.debug_struct("FixpointExec")
5214 .field("rules_count", &self.rules.len())
5215 .field("max_iterations", &self.max_iterations)
5216 .field("timeout", &self.timeout)
5217 .field("output_schema", &self.output_schema)
5218 .field("max_derived_bytes", &self.max_derived_bytes)
5219 .finish_non_exhaustive()
5220 }
5221}
5222
5223impl FixpointExec {
5224 #[expect(
5226 clippy::too_many_arguments,
5227 reason = "FixpointExec configuration needs all context"
5228 )]
5229 #[deprecated(
5230 note = "use `new_with_semiring_classifiers_and_cache` (or the lighter \
5231 `new_with_semiring_and_classifiers` / `new_with_semiring`) — \
5232 this legacy ctor defaults the semiring to AddMultProb and \
5233 ships no classifier registry, which the Phase B+ runtime needs \
5234 explicitly. To be removed after C0 Stage 2."
5235 )]
5236 pub fn new(
5237 rules: Vec<FixpointRulePlan>,
5238 max_iterations: usize,
5239 timeout: Duration,
5240 graph_ctx: Arc<GraphExecutionContext>,
5241 session_ctx: Arc<RwLock<datafusion::prelude::SessionContext>>,
5242 storage: Arc<StorageManager>,
5243 schema_info: Arc<UniSchema>,
5244 params: HashMap<String, Value>,
5245 derived_scan_registry: Arc<DerivedScanRegistry>,
5246 output_schema: SchemaRef,
5247 max_derived_bytes: usize,
5248 derivation_tracker: Option<Arc<ProvenanceStore>>,
5249 iteration_counts: Arc<StdRwLock<HashMap<String, usize>>>,
5250 strict_probability_domain: bool,
5251 probability_epsilon: f64,
5252 exact_probability: bool,
5253 max_bdd_variables: usize,
5254 warnings_slot: Arc<StdRwLock<Vec<RuntimeWarning>>>,
5255 approximate_slot: Arc<StdRwLock<HashMap<String, Vec<String>>>>,
5256 top_k_proofs: usize,
5257 timeout_flag: Arc<std::sync::atomic::AtomicU8>,
5258 ) -> Self {
5259 Self::new_with_semiring_and_classifiers(
5260 rules,
5261 max_iterations,
5262 timeout,
5263 graph_ctx,
5264 session_ctx,
5265 storage,
5266 schema_info,
5267 params,
5268 derived_scan_registry,
5269 output_schema,
5270 max_derived_bytes,
5271 derivation_tracker,
5272 iteration_counts,
5273 strict_probability_domain,
5274 probability_epsilon,
5275 exact_probability,
5276 max_bdd_variables,
5277 warnings_slot,
5278 approximate_slot,
5279 top_k_proofs,
5280 timeout_flag,
5281 SemiringKind::AddMultProb,
5282 Arc::new(ClassifierRegistry::new()),
5283 )
5284 }
5285
5286 #[expect(
5290 clippy::too_many_arguments,
5291 reason = "FixpointExec configuration needs all context"
5292 )]
5293 pub fn new_with_semiring(
5294 rules: Vec<FixpointRulePlan>,
5295 max_iterations: usize,
5296 timeout: Duration,
5297 graph_ctx: Arc<GraphExecutionContext>,
5298 session_ctx: Arc<RwLock<datafusion::prelude::SessionContext>>,
5299 storage: Arc<StorageManager>,
5300 schema_info: Arc<UniSchema>,
5301 params: HashMap<String, Value>,
5302 derived_scan_registry: Arc<DerivedScanRegistry>,
5303 output_schema: SchemaRef,
5304 max_derived_bytes: usize,
5305 derivation_tracker: Option<Arc<ProvenanceStore>>,
5306 iteration_counts: Arc<StdRwLock<HashMap<String, usize>>>,
5307 strict_probability_domain: bool,
5308 probability_epsilon: f64,
5309 exact_probability: bool,
5310 max_bdd_variables: usize,
5311 warnings_slot: Arc<StdRwLock<Vec<RuntimeWarning>>>,
5312 approximate_slot: Arc<StdRwLock<HashMap<String, Vec<String>>>>,
5313 top_k_proofs: usize,
5314 timeout_flag: Arc<std::sync::atomic::AtomicU8>,
5315 semiring_kind: SemiringKind,
5316 ) -> Self {
5317 Self::new_with_semiring_and_classifiers(
5318 rules,
5319 max_iterations,
5320 timeout,
5321 graph_ctx,
5322 session_ctx,
5323 storage,
5324 schema_info,
5325 params,
5326 derived_scan_registry,
5327 output_schema,
5328 max_derived_bytes,
5329 derivation_tracker,
5330 iteration_counts,
5331 strict_probability_domain,
5332 probability_epsilon,
5333 exact_probability,
5334 max_bdd_variables,
5335 warnings_slot,
5336 approximate_slot,
5337 top_k_proofs,
5338 timeout_flag,
5339 semiring_kind,
5340 Arc::new(ClassifierRegistry::new()),
5341 )
5342 }
5343
5344 #[expect(
5348 clippy::too_many_arguments,
5349 reason = "FixpointExec configuration needs all context"
5350 )]
5351 pub fn new_with_semiring_and_classifiers(
5352 rules: Vec<FixpointRulePlan>,
5353 max_iterations: usize,
5354 timeout: Duration,
5355 graph_ctx: Arc<GraphExecutionContext>,
5356 session_ctx: Arc<RwLock<datafusion::prelude::SessionContext>>,
5357 storage: Arc<StorageManager>,
5358 schema_info: Arc<UniSchema>,
5359 params: HashMap<String, Value>,
5360 derived_scan_registry: Arc<DerivedScanRegistry>,
5361 output_schema: SchemaRef,
5362 max_derived_bytes: usize,
5363 derivation_tracker: Option<Arc<ProvenanceStore>>,
5364 iteration_counts: Arc<StdRwLock<HashMap<String, usize>>>,
5365 strict_probability_domain: bool,
5366 probability_epsilon: f64,
5367 exact_probability: bool,
5368 max_bdd_variables: usize,
5369 warnings_slot: Arc<StdRwLock<Vec<RuntimeWarning>>>,
5370 approximate_slot: Arc<StdRwLock<HashMap<String, Vec<String>>>>,
5371 top_k_proofs: usize,
5372 timeout_flag: Arc<std::sync::atomic::AtomicU8>,
5373 semiring_kind: SemiringKind,
5374 classifier_registry: Arc<ClassifierRegistry>,
5375 ) -> Self {
5376 Self::new_with_semiring_classifiers_and_cache(
5377 rules,
5378 max_iterations,
5379 timeout,
5380 graph_ctx,
5381 session_ctx,
5382 storage,
5383 schema_info,
5384 params,
5385 derived_scan_registry,
5386 output_schema,
5387 max_derived_bytes,
5388 derivation_tracker,
5389 iteration_counts,
5390 strict_probability_domain,
5391 probability_epsilon,
5392 exact_probability,
5393 max_bdd_variables,
5394 warnings_slot,
5395 approximate_slot,
5396 top_k_proofs,
5397 timeout_flag,
5398 semiring_kind,
5399 classifier_registry,
5400 None,
5401 None,
5402 )
5403 }
5404
5405 #[expect(
5409 clippy::too_many_arguments,
5410 reason = "FixpointExec configuration needs all context"
5411 )]
5412 pub fn new_with_semiring_classifiers_and_cache(
5413 rules: Vec<FixpointRulePlan>,
5414 max_iterations: usize,
5415 timeout: Duration,
5416 graph_ctx: Arc<GraphExecutionContext>,
5417 session_ctx: Arc<RwLock<datafusion::prelude::SessionContext>>,
5418 storage: Arc<StorageManager>,
5419 schema_info: Arc<UniSchema>,
5420 params: HashMap<String, Value>,
5421 derived_scan_registry: Arc<DerivedScanRegistry>,
5422 output_schema: SchemaRef,
5423 max_derived_bytes: usize,
5424 derivation_tracker: Option<Arc<ProvenanceStore>>,
5425 iteration_counts: Arc<StdRwLock<HashMap<String, usize>>>,
5426 strict_probability_domain: bool,
5427 probability_epsilon: f64,
5428 exact_probability: bool,
5429 max_bdd_variables: usize,
5430 warnings_slot: Arc<StdRwLock<Vec<RuntimeWarning>>>,
5431 approximate_slot: Arc<StdRwLock<HashMap<String, Vec<String>>>>,
5432 top_k_proofs: usize,
5433 timeout_flag: Arc<std::sync::atomic::AtomicU8>,
5434 semiring_kind: SemiringKind,
5435 classifier_registry: Arc<ClassifierRegistry>,
5436 classifier_cache: Option<Arc<ModelInvocationCache>>,
5437 classifier_provenance_store: Option<Arc<uni_locy::NeuralProvenanceStore>>,
5438 ) -> Self {
5439 let properties = compute_plan_properties(Arc::clone(&output_schema));
5440 Self {
5441 rules,
5442 max_iterations,
5443 timeout,
5444 graph_ctx,
5445 session_ctx,
5446 storage,
5447 schema_info,
5448 params,
5449 derived_scan_registry,
5450 output_schema,
5451 properties,
5452 metrics: ExecutionPlanMetricsSet::new(),
5453 max_derived_bytes,
5454 derivation_tracker,
5455 iteration_counts,
5456 strict_probability_domain,
5457 probability_epsilon,
5458 exact_probability,
5459 max_bdd_variables,
5460 warnings_slot,
5461 approximate_slot,
5462 top_k_proofs,
5463 timeout_flag,
5464 semiring_kind,
5465 classifier_registry,
5466 classifier_cache,
5467 classifier_provenance_store,
5468 profile_collector: None,
5469 }
5470 }
5471
5472 pub fn iteration_counts(&self) -> Arc<StdRwLock<HashMap<String, usize>>> {
5474 Arc::clone(&self.iteration_counts)
5475 }
5476
5477 pub fn set_profile_collector(&mut self, collector: Arc<LocyProfileCollector>) {
5483 self.profile_collector = Some(collector);
5484 }
5485}
5486
5487impl DisplayAs for FixpointExec {
5488 fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter<'_>) -> fmt::Result {
5489 write!(
5490 f,
5491 "FixpointExec: rules=[{}], max_iter={}, timeout={:?}",
5492 self.rules
5493 .iter()
5494 .map(|r| r.name.as_str())
5495 .collect::<Vec<_>>()
5496 .join(", "),
5497 self.max_iterations,
5498 self.timeout,
5499 )
5500 }
5501}
5502
5503impl ExecutionPlan for FixpointExec {
5504 fn name(&self) -> &str {
5505 "FixpointExec"
5506 }
5507
5508 fn as_any(&self) -> &dyn Any {
5509 self
5510 }
5511
5512 fn schema(&self) -> SchemaRef {
5513 Arc::clone(&self.output_schema)
5514 }
5515
5516 fn properties(&self) -> &Arc<PlanProperties> {
5517 &self.properties
5518 }
5519
5520 fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
5521 vec![]
5523 }
5524
5525 fn with_new_children(
5526 self: Arc<Self>,
5527 children: Vec<Arc<dyn ExecutionPlan>>,
5528 ) -> DFResult<Arc<dyn ExecutionPlan>> {
5529 if !children.is_empty() {
5530 return Err(datafusion::error::DataFusionError::Plan(
5531 "FixpointExec has no children".to_string(),
5532 ));
5533 }
5534 Ok(self)
5535 }
5536
5537 fn execute(
5538 &self,
5539 partition: usize,
5540 _context: Arc<TaskContext>,
5541 ) -> DFResult<SendableRecordBatchStream> {
5542 let metrics = BaselineMetrics::new(&self.metrics, partition);
5543
5544 let rules = self
5546 .rules
5547 .iter()
5548 .map(|r| {
5549 FixpointRulePlan {
5553 name: r.name.clone(),
5554 clauses: r
5555 .clauses
5556 .iter()
5557 .map(|c| FixpointClausePlan {
5558 body_logical: c.body_logical.clone(),
5559 is_ref_bindings: c.is_ref_bindings.clone(),
5560 priority: c.priority,
5561 along_bindings: c.along_bindings.clone(),
5562 model_invocations: c.model_invocations.clone(),
5563 })
5564 .collect(),
5565 yield_schema: Arc::clone(&r.yield_schema),
5566 key_column_indices: r.key_column_indices.clone(),
5567 priority: r.priority,
5568 has_fold: r.has_fold,
5569 fold_bindings: r.fold_bindings.clone(),
5570 having: r.having.clone(),
5571 has_best_by: r.has_best_by,
5572 best_by_criteria: r.best_by_criteria.clone(),
5573 has_priority: r.has_priority,
5574 deterministic: r.deterministic,
5575 prob_column_name: r.prob_column_name.clone(),
5576 non_linear: r.non_linear,
5577 }
5578 })
5579 .collect();
5580
5581 let max_iterations = self.max_iterations;
5582 let timeout = self.timeout;
5583 let graph_ctx = Arc::clone(&self.graph_ctx);
5584 let session_ctx = Arc::clone(&self.session_ctx);
5585 let storage = Arc::clone(&self.storage);
5586 let schema_info = Arc::clone(&self.schema_info);
5587 let params = self.params.clone();
5588 let registry = Arc::clone(&self.derived_scan_registry);
5589 let output_schema = Arc::clone(&self.output_schema);
5590 let max_derived_bytes = self.max_derived_bytes;
5591 let derivation_tracker = self.derivation_tracker.clone();
5592 let iteration_counts = Arc::clone(&self.iteration_counts);
5593 let strict_probability_domain = self.strict_probability_domain;
5594 let probability_epsilon = self.probability_epsilon;
5595 let exact_probability = self.exact_probability;
5596 let max_bdd_variables = self.max_bdd_variables;
5597 let warnings_slot = Arc::clone(&self.warnings_slot);
5598 let approximate_slot = Arc::clone(&self.approximate_slot);
5599 let top_k_proofs = self.top_k_proofs;
5600 let timeout_flag = Arc::clone(&self.timeout_flag);
5601 let semiring_kind = self.semiring_kind;
5602 let classifier_registry = Arc::clone(&self.classifier_registry);
5603 let classifier_cache = self.classifier_cache.as_ref().map(Arc::clone);
5604 let classifier_provenance_store = self.classifier_provenance_store.as_ref().map(Arc::clone);
5605 let profile_collector = self.profile_collector.as_ref().map(Arc::clone);
5606
5607 let fut = async move {
5608 run_fixpoint_loop(
5609 rules,
5610 max_iterations,
5611 timeout,
5612 graph_ctx,
5613 session_ctx,
5614 storage,
5615 schema_info,
5616 params,
5617 registry,
5618 output_schema,
5619 max_derived_bytes,
5620 derivation_tracker,
5621 iteration_counts,
5622 strict_probability_domain,
5623 probability_epsilon,
5624 exact_probability,
5625 max_bdd_variables,
5626 warnings_slot,
5627 approximate_slot,
5628 top_k_proofs,
5629 timeout_flag,
5630 semiring_kind,
5631 classifier_registry,
5632 classifier_cache,
5633 classifier_provenance_store,
5634 profile_collector,
5635 )
5636 .await
5637 };
5638
5639 Ok(Box::pin(FixpointStream {
5640 state: FixpointStreamState::Running(Box::pin(fut)),
5641 schema: Arc::clone(&self.output_schema),
5642 metrics,
5643 }))
5644 }
5645
5646 fn metrics(&self) -> Option<MetricsSet> {
5647 Some(self.metrics.clone_inner())
5648 }
5649}
5650
/// State machine behind `FixpointStream`.
///
/// Transitions (as driven by `FixpointStream::poll_next`):
/// `Running` -> `Emitting` when the future yields a non-empty batch list,
/// `Running` -> `Done` on an empty result or an error, and
/// `Emitting` -> `Done` once every batch has been yielded.
enum FixpointStreamState {
    /// The fixpoint future is still being polled; it resolves to every output batch at once.
    Running(Pin<Box<dyn std::future::Future<Output = DFResult<Vec<RecordBatch>>> + Send>>),
    /// Result batches plus the index of the next batch to yield.
    Emitting(Vec<RecordBatch>, usize),
    /// Terminal state: every further poll returns `Poll::Ready(None)`.
    Done,
}
5663
/// Record-batch stream returned by `FixpointExec::execute`.
struct FixpointStream {
    /// Current position in the run -> emit -> done state machine.
    state: FixpointStreamState,
    /// Schema reported via `RecordBatchStream::schema`.
    schema: SchemaRef,
    /// Per-partition baseline metrics (compute time and output row counts).
    metrics: BaselineMetrics,
}
5669
5670impl Stream for FixpointStream {
5671 type Item = DFResult<RecordBatch>;
5672
5673 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
5674 let this = self.get_mut();
5675 let metrics = this.metrics.clone();
5676 let _timer = metrics.elapsed_compute().timer();
5677 loop {
5678 match &mut this.state {
5679 FixpointStreamState::Running(fut) => match fut.as_mut().poll(cx) {
5680 Poll::Ready(Ok(batches)) => {
5681 if batches.is_empty() {
5682 this.state = FixpointStreamState::Done;
5683 return Poll::Ready(None);
5684 }
5685 this.state = FixpointStreamState::Emitting(batches, 0);
5686 }
5688 Poll::Ready(Err(e)) => {
5689 this.state = FixpointStreamState::Done;
5690 return Poll::Ready(Some(Err(e)));
5691 }
5692 Poll::Pending => return Poll::Pending,
5693 },
5694 FixpointStreamState::Emitting(batches, idx) => {
5695 if *idx >= batches.len() {
5696 this.state = FixpointStreamState::Done;
5697 return Poll::Ready(None);
5698 }
5699 let batch = batches[*idx].clone();
5700 *idx += 1;
5701 this.metrics.record_output(batch.num_rows());
5702 return Poll::Ready(Some(Ok(batch)));
5703 }
5704 FixpointStreamState::Done => return Poll::Ready(None),
5705 }
5706 }
5707 }
5708}
5709
5710impl RecordBatchStream for FixpointStream {
5711 fn schema(&self) -> SchemaRef {
5712 Arc::clone(&self.schema)
5713 }
5714}
5715
5716#[cfg(test)]
5721mod tests {
5722 use super::*;
5723 use arrow_array::{Float64Array, Int64Array, StringArray};
5724 use arrow_schema::{DataType, Field, Schema};
5725
5726 fn test_schema() -> SchemaRef {
5727 Arc::new(Schema::new(vec![
5728 Field::new("name", DataType::Utf8, true),
5729 Field::new("value", DataType::Int64, true),
5730 ]))
5731 }
5732
5733 fn make_batch(names: &[&str], values: &[i64]) -> RecordBatch {
5734 RecordBatch::try_new(
5735 test_schema(),
5736 vec![
5737 Arc::new(StringArray::from(
5738 names.iter().map(|s| Some(*s)).collect::<Vec<_>>(),
5739 )),
5740 Arc::new(Int64Array::from(values.to_vec())),
5741 ],
5742 )
5743 .unwrap()
5744 }
5745
5746 #[tokio::test]
5749 async fn test_fixpoint_state_empty_facts_adds_all() {
5750 let schema = test_schema();
5751 let mut state = FixpointState::new("test".into(), schema, vec![0], 1_000_000, None, false);
5752
5753 let batch = make_batch(&["a", "b", "c"], &[1, 2, 3]);
5754 let changed = state.merge_delta(vec![batch], None).await.unwrap();
5755
5756 assert!(changed);
5757 assert_eq!(state.all_facts().len(), 1);
5758 assert_eq!(state.all_facts()[0].num_rows(), 3);
5759 assert_eq!(state.all_delta().len(), 1);
5760 assert_eq!(state.all_delta()[0].num_rows(), 3);
5761 }
5762
5763 #[tokio::test]
5764 async fn test_fixpoint_state_exact_duplicates_excluded() {
5765 let schema = test_schema();
5766 let mut state = FixpointState::new("test".into(), schema, vec![0], 1_000_000, None, false);
5767
5768 let batch1 = make_batch(&["a", "b"], &[1, 2]);
5769 state.merge_delta(vec![batch1], None).await.unwrap();
5770
5771 let batch2 = make_batch(&["a", "b"], &[1, 2]);
5773 let changed = state.merge_delta(vec![batch2], None).await.unwrap();
5774 assert!(!changed);
5775 assert!(
5776 state.all_delta().is_empty() || state.all_delta().iter().all(|b| b.num_rows() == 0)
5777 );
5778 }
5779
5780 #[tokio::test]
5781 async fn test_fixpoint_state_partial_overlap() {
5782 let schema = test_schema();
5783 let mut state = FixpointState::new("test".into(), schema, vec![0], 1_000_000, None, false);
5784
5785 let batch1 = make_batch(&["a", "b"], &[1, 2]);
5786 state.merge_delta(vec![batch1], None).await.unwrap();
5787
5788 let batch2 = make_batch(&["a", "c"], &[1, 3]);
5790 let changed = state.merge_delta(vec![batch2], None).await.unwrap();
5791 assert!(changed);
5792
5793 let delta_rows: usize = state.all_delta().iter().map(|b| b.num_rows()).sum();
5795 assert_eq!(delta_rows, 1);
5796
5797 let total_rows: usize = state.all_facts().iter().map(|b| b.num_rows()).sum();
5799 assert_eq!(total_rows, 3);
5800 }
5801
5802 #[tokio::test]
5803 async fn test_fixpoint_state_convergence() {
5804 let schema = test_schema();
5805 let mut state = FixpointState::new("test".into(), schema, vec![0], 1_000_000, None, false);
5806
5807 let batch = make_batch(&["a"], &[1]);
5808 state.merge_delta(vec![batch], None).await.unwrap();
5809
5810 let changed = state.merge_delta(vec![], None).await.unwrap();
5812 assert!(!changed);
5813 assert!(state.is_converged());
5814 }
5815
5816 #[test]
5819 fn test_row_dedup_persistent_across_calls() {
5820 let schema = test_schema();
5823 let mut rd = RowDedupState::try_new(&schema).expect("schema should be supported");
5824
5825 let batch1 = make_batch(&["a", "b"], &[1, 2]);
5826 let delta1 = rd.compute_delta(&[batch1], &schema).unwrap();
5827 let rows1: usize = delta1.iter().map(|b| b.num_rows()).sum();
5829 assert_eq!(rows1, 2);
5830
5831 let batch2 = make_batch(&["a", "b"], &[1, 2]);
5833 let delta2 = rd.compute_delta(&[batch2], &schema).unwrap();
5834 let rows2: usize = delta2.iter().map(|b| b.num_rows()).sum();
5835 assert_eq!(rows2, 0);
5836
5837 let batch3 = make_batch(&["a", "c"], &[1, 3]);
5839 let delta3 = rd.compute_delta(&[batch3], &schema).unwrap();
5840 let rows3: usize = delta3.iter().map(|b| b.num_rows()).sum();
5841 assert_eq!(rows3, 1);
5842 }
5843
5844 #[test]
5845 fn test_row_dedup_null_handling() {
5846 use arrow_array::StringArray;
5847 use arrow_schema::{DataType, Field, Schema};
5848
5849 let schema: SchemaRef = Arc::new(Schema::new(vec![
5850 Field::new("a", DataType::Utf8, true),
5851 Field::new("b", DataType::Int64, true),
5852 ]));
5853 let mut rd = RowDedupState::try_new(&schema).expect("schema should be supported");
5854
5855 let batch_nulls = RecordBatch::try_new(
5857 Arc::clone(&schema),
5858 vec![
5859 Arc::new(StringArray::from(vec![None::<&str>, None::<&str>])),
5860 Arc::new(arrow_array::Int64Array::from(vec![1i64, 1i64])),
5861 ],
5862 )
5863 .unwrap();
5864 let delta = rd.compute_delta(&[batch_nulls], &schema).unwrap();
5865 let rows: usize = delta.iter().map(|b| b.num_rows()).sum();
5866 assert_eq!(rows, 1, "two identical NULL rows should be deduped to one");
5867
5868 let batch_diff = RecordBatch::try_new(
5870 Arc::clone(&schema),
5871 vec![
5872 Arc::new(StringArray::from(vec![None::<&str>])),
5873 Arc::new(arrow_array::Int64Array::from(vec![2i64])),
5874 ],
5875 )
5876 .unwrap();
5877 let delta2 = rd.compute_delta(&[batch_diff], &schema).unwrap();
5878 let rows2: usize = delta2.iter().map(|b| b.num_rows()).sum();
5879 assert_eq!(rows2, 1, "(NULL, 2) is distinct from (NULL, 1)");
5880 }
5881
5882 #[test]
5883 fn test_row_dedup_within_candidate_dedup() {
5884 let schema = test_schema();
5886 let mut rd = RowDedupState::try_new(&schema).expect("schema should be supported");
5887
5888 let batch = make_batch(&["a", "a", "b"], &[1, 1, 2]);
5890 let delta = rd.compute_delta(&[batch], &schema).unwrap();
5891 let rows: usize = delta.iter().map(|b| b.num_rows()).sum();
5892 assert_eq!(rows, 2, "within-batch dup should be collapsed: a:1, b:2");
5893 }
5894
5895 #[test]
5898 fn test_round_float_columns_near_duplicates() {
5899 let schema = Arc::new(Schema::new(vec![
5900 Field::new("name", DataType::Utf8, true),
5901 Field::new("dist", DataType::Float64, true),
5902 ]));
5903 let batch = RecordBatch::try_new(
5904 Arc::clone(&schema),
5905 vec![
5906 Arc::new(StringArray::from(vec![Some("a"), Some("a")])),
5907 Arc::new(Float64Array::from(vec![1.0000000000001, 1.0000000000002])),
5908 ],
5909 )
5910 .unwrap();
5911
5912 let rounded = round_float_columns(&[batch]);
5913 assert_eq!(rounded.len(), 1);
5914 let col = rounded[0]
5915 .column(1)
5916 .as_any()
5917 .downcast_ref::<Float64Array>()
5918 .unwrap();
5919 assert_eq!(col.value(0), col.value(1));
5921 }
5922
5923 #[test]
5926 fn test_registry_write_read_round_trip() {
5927 let schema = test_schema();
5928 let data = Arc::new(RwLock::new(Vec::new()));
5929 let mut reg = DerivedScanRegistry::new();
5930 reg.add(DerivedScanEntry {
5931 scan_index: 0,
5932 rule_name: "reachable".into(),
5933 is_self_ref: true,
5934 data: Arc::clone(&data),
5935 schema: Arc::clone(&schema),
5936 });
5937
5938 let batch = make_batch(&["x"], &[42]);
5939 reg.write_data(0, vec![batch.clone()]);
5940
5941 let entry = reg.get(0).unwrap();
5942 let guard = entry.data.read();
5943 assert_eq!(guard.len(), 1);
5944 assert_eq!(guard[0].num_rows(), 1);
5945 }
5946
5947 #[test]
5948 fn test_registry_entries_for_rule() {
5949 let schema = test_schema();
5950 let mut reg = DerivedScanRegistry::new();
5951 reg.add(DerivedScanEntry {
5952 scan_index: 0,
5953 rule_name: "r1".into(),
5954 is_self_ref: true,
5955 data: Arc::new(RwLock::new(Vec::new())),
5956 schema: Arc::clone(&schema),
5957 });
5958 reg.add(DerivedScanEntry {
5959 scan_index: 1,
5960 rule_name: "r2".into(),
5961 is_self_ref: false,
5962 data: Arc::new(RwLock::new(Vec::new())),
5963 schema: Arc::clone(&schema),
5964 });
5965 reg.add(DerivedScanEntry {
5966 scan_index: 2,
5967 rule_name: "r1".into(),
5968 is_self_ref: false,
5969 data: Arc::new(RwLock::new(Vec::new())),
5970 schema: Arc::clone(&schema),
5971 });
5972
5973 assert_eq!(reg.entries_for_rule("r1").len(), 2);
5974 assert_eq!(reg.entries_for_rule("r2").len(), 1);
5975 assert_eq!(reg.entries_for_rule("r3").len(), 0);
5976 }
5977
5978 #[test]
5981 fn test_monotonic_agg_update_and_stability() {
5982 let bindings = vec![MonotonicFoldBinding {
5983 fold_name: "total".into(),
5984 aggregate: std::sync::Arc::new(uni_plugin_builtin::locy_aggregates::SumAgg),
5985 input_col_index: 1,
5986 input_col_name: None,
5987 }];
5988 let mut agg = MonotonicAggState::new(bindings);
5989
5990 let batch = make_batch(&["a"], &[10]);
5992 agg.snapshot();
5993 let changed = agg
5994 .update(&[0], &[batch], false, SemiringKind::AddMultProb)
5995 .unwrap();
5996 assert!(changed);
5997 assert!(!agg.is_stable()); agg.snapshot();
6001 let changed = agg
6002 .update(&[0], &[], false, SemiringKind::AddMultProb)
6003 .unwrap();
6004 assert!(!changed);
6005 assert!(agg.is_stable());
6006 }
6007
6008 #[tokio::test]
6011 async fn test_memory_limit_exceeded() {
6012 let schema = test_schema();
6013 let mut state = FixpointState::new("test".into(), schema, vec![0], 1, None, false);
6015
6016 let batch = make_batch(&["a", "b", "c"], &[1, 2, 3]);
6017 let result = state.merge_delta(vec![batch], None).await;
6018 assert!(result.is_err());
6019 let err = result.unwrap_err().to_string();
6020 assert!(err.contains("memory limit"), "Error was: {}", err);
6021 }
6022
6023 #[tokio::test]
6026 async fn test_fixpoint_stream_emitting() {
6027 use futures::StreamExt;
6028
6029 let schema = test_schema();
6030 let batch1 = make_batch(&["a"], &[1]);
6031 let batch2 = make_batch(&["b"], &[2]);
6032
6033 let metrics = ExecutionPlanMetricsSet::new();
6034 let baseline = BaselineMetrics::new(&metrics, 0);
6035
6036 let mut stream = FixpointStream {
6037 state: FixpointStreamState::Emitting(vec![batch1, batch2], 0),
6038 schema,
6039 metrics: baseline,
6040 };
6041
6042 let stream = Pin::new(&mut stream);
6043 let batches: Vec<RecordBatch> = stream.filter_map(|r| async { r.ok() }).collect().await;
6044
6045 assert_eq!(batches.len(), 2);
6046 assert_eq!(batches[0].num_rows(), 1);
6047 assert_eq!(batches[1].num_rows(), 1);
6048 }
6049
6050 fn make_f64_batch(names: &[&str], values: &[f64]) -> RecordBatch {
6053 let schema = Arc::new(Schema::new(vec![
6054 Field::new("name", DataType::Utf8, true),
6055 Field::new("value", DataType::Float64, true),
6056 ]));
6057 RecordBatch::try_new(
6058 schema,
6059 vec![
6060 Arc::new(StringArray::from(
6061 names.iter().map(|s| Some(*s)).collect::<Vec<_>>(),
6062 )),
6063 Arc::new(Float64Array::from(values.to_vec())),
6064 ],
6065 )
6066 .unwrap()
6067 }
6068
6069 fn make_nor_binding() -> Vec<MonotonicFoldBinding> {
6070 vec![MonotonicFoldBinding {
6071 fold_name: "prob".into(),
6072 aggregate: std::sync::Arc::new(uni_plugin_builtin::locy_aggregates::MnorAgg),
6073 input_col_index: 1,
6074 input_col_name: None,
6075 }]
6076 }
6077
6078 fn make_prod_binding() -> Vec<MonotonicFoldBinding> {
6079 vec![MonotonicFoldBinding {
6080 fold_name: "prob".into(),
6081 aggregate: std::sync::Arc::new(uni_plugin_builtin::locy_aggregates::MprodAgg),
6082 input_col_index: 1,
6083 input_col_name: None,
6084 }]
6085 }
6086
6087 fn acc_key(name: &str) -> (Vec<ScalarKey>, String) {
6088 (vec![ScalarKey::Utf8(name.to_string())], "prob".to_string())
6089 }
6090
6091 #[test]
6092 fn test_monotonic_nor_first_update() {
6093 let mut agg = MonotonicAggState::new(make_nor_binding());
6094 let batch = make_f64_batch(&["a"], &[0.3]);
6095 let changed = agg
6096 .update(&[0], &[batch], false, SemiringKind::AddMultProb)
6097 .unwrap();
6098 assert!(changed);
6099 let val = agg.get_accumulator(&acc_key("a")).unwrap();
6100 assert!((val - 0.3).abs() < 1e-10, "expected 0.3, got {}", val);
6101 }
6102
6103 #[test]
6104 fn test_monotonic_nor_two_updates() {
6105 let mut agg = MonotonicAggState::new(make_nor_binding());
6107 let batch1 = make_f64_batch(&["a"], &[0.3]);
6108 agg.update(&[0], &[batch1], false, SemiringKind::AddMultProb)
6109 .unwrap();
6110 let batch2 = make_f64_batch(&["a"], &[0.5]);
6111 agg.update(&[0], &[batch2], false, SemiringKind::AddMultProb)
6112 .unwrap();
6113 let val = agg.get_accumulator(&acc_key("a")).unwrap();
6114 assert!((val - 0.65).abs() < 1e-10, "expected 0.65, got {}", val);
6115 }
6116
6117 #[test]
6118 fn test_monotonic_prod_first_update() {
6119 let mut agg = MonotonicAggState::new(make_prod_binding());
6120 let batch = make_f64_batch(&["a"], &[0.6]);
6121 let changed = agg
6122 .update(&[0], &[batch], false, SemiringKind::AddMultProb)
6123 .unwrap();
6124 assert!(changed);
6125 let val = agg.get_accumulator(&acc_key("a")).unwrap();
6126 assert!((val - 0.6).abs() < 1e-10, "expected 0.6, got {}", val);
6127 }
6128
6129 #[test]
6130 fn test_monotonic_prod_two_updates() {
6131 let mut agg = MonotonicAggState::new(make_prod_binding());
6133 let batch1 = make_f64_batch(&["a"], &[0.6]);
6134 agg.update(&[0], &[batch1], false, SemiringKind::AddMultProb)
6135 .unwrap();
6136 let batch2 = make_f64_batch(&["a"], &[0.8]);
6137 agg.update(&[0], &[batch2], false, SemiringKind::AddMultProb)
6138 .unwrap();
6139 let val = agg.get_accumulator(&acc_key("a")).unwrap();
6140 assert!((val - 0.48).abs() < 1e-10, "expected 0.48, got {}", val);
6141 }
6142
6143 #[test]
6144 fn test_monotonic_nor_stability() {
6145 let mut agg = MonotonicAggState::new(make_nor_binding());
6146 let batch = make_f64_batch(&["a"], &[0.3]);
6147 agg.update(&[0], &[batch], false, SemiringKind::AddMultProb)
6148 .unwrap();
6149 agg.snapshot();
6150 let changed = agg
6151 .update(&[0], &[], false, SemiringKind::AddMultProb)
6152 .unwrap();
6153 assert!(!changed);
6154 assert!(agg.is_stable());
6155 }
6156
6157 #[test]
6158 fn test_monotonic_prod_stability() {
6159 let mut agg = MonotonicAggState::new(make_prod_binding());
6160 let batch = make_f64_batch(&["a"], &[0.6]);
6161 agg.update(&[0], &[batch], false, SemiringKind::AddMultProb)
6162 .unwrap();
6163 agg.snapshot();
6164 let changed = agg
6165 .update(&[0], &[], false, SemiringKind::AddMultProb)
6166 .unwrap();
6167 assert!(!changed);
6168 assert!(agg.is_stable());
6169 }
6170
6171 #[test]
6172 fn test_monotonic_nor_multi_group() {
6173 let mut agg = MonotonicAggState::new(make_nor_binding());
6175 let batch1 = make_f64_batch(&["a", "b"], &[0.3, 0.5]);
6176 agg.update(&[0], &[batch1], false, SemiringKind::AddMultProb)
6177 .unwrap();
6178 let batch2 = make_f64_batch(&["a", "b"], &[0.5, 0.2]);
6179 agg.update(&[0], &[batch2], false, SemiringKind::AddMultProb)
6180 .unwrap();
6181
6182 let val_a = agg.get_accumulator(&acc_key("a")).unwrap();
6183 let val_b = agg.get_accumulator(&acc_key("b")).unwrap();
6184 assert!(
6185 (val_a - 0.65).abs() < 1e-10,
6186 "expected a=0.65, got {}",
6187 val_a
6188 );
6189 assert!((val_b - 0.6).abs() < 1e-10, "expected b=0.6, got {}", val_b);
6190 }
6191
6192 #[test]
6193 fn test_monotonic_prod_zero_absorbing() {
6194 let mut agg = MonotonicAggState::new(make_prod_binding());
6196 let batch1 = make_f64_batch(&["a"], &[0.5]);
6197 agg.update(&[0], &[batch1], false, SemiringKind::AddMultProb)
6198 .unwrap();
6199 let batch2 = make_f64_batch(&["a"], &[0.0]);
6200 agg.update(&[0], &[batch2], false, SemiringKind::AddMultProb)
6201 .unwrap();
6202
6203 let val = agg.get_accumulator(&acc_key("a")).unwrap();
6204 assert!((val - 0.0).abs() < 1e-10, "expected 0.0, got {}", val);
6205
6206 agg.snapshot();
6208 let batch3 = make_f64_batch(&["a"], &[0.5]);
6209 let changed = agg
6210 .update(&[0], &[batch3], false, SemiringKind::AddMultProb)
6211 .unwrap();
6212 assert!(!changed);
6213 assert!(agg.is_stable());
6214 }
6215
6216 #[test]
6217 fn test_monotonic_nor_clamping() {
6218 let mut agg = MonotonicAggState::new(make_nor_binding());
6220 let batch = make_f64_batch(&["a"], &[1.5]);
6221 agg.update(&[0], &[batch], false, SemiringKind::AddMultProb)
6222 .unwrap();
6223 let val = agg.get_accumulator(&acc_key("a")).unwrap();
6224 assert!((val - 1.0).abs() < 1e-10, "expected 1.0, got {}", val);
6225 }
6226
6227 #[test]
6228 fn test_monotonic_nor_absorbing() {
6229 let mut agg = MonotonicAggState::new(make_nor_binding());
6231 let batch1 = make_f64_batch(&["a"], &[0.3]);
6232 agg.update(&[0], &[batch1], false, SemiringKind::AddMultProb)
6233 .unwrap();
6234 let batch2 = make_f64_batch(&["a"], &[1.0]);
6235 agg.update(&[0], &[batch2], false, SemiringKind::AddMultProb)
6236 .unwrap();
6237 let val = agg.get_accumulator(&acc_key("a")).unwrap();
6238 assert!((val - 1.0).abs() < 1e-10, "expected 1.0, got {}", val);
6239 }
6240
6241 #[test]
6244 fn test_monotonic_agg_strict_nor_rejects() {
6245 let mut agg = MonotonicAggState::new(make_nor_binding());
6246 let batch = make_f64_batch(&["a"], &[1.5]);
6247 let result = agg.update(&[0], &[batch], true, SemiringKind::AddMultProb);
6248 assert!(result.is_err());
6249 let err = result.unwrap_err().to_string();
6250 assert!(
6251 err.contains("strict_probability_domain"),
6252 "Expected strict error, got: {}",
6253 err
6254 );
6255 }
6256
6257 #[test]
6258 fn test_monotonic_agg_strict_prod_rejects() {
6259 let mut agg = MonotonicAggState::new(make_prod_binding());
6260 let batch = make_f64_batch(&["a"], &[2.0]);
6261 let result = agg.update(&[0], &[batch], true, SemiringKind::AddMultProb);
6262 assert!(result.is_err());
6263 let err = result.unwrap_err().to_string();
6264 assert!(
6265 err.contains("strict_probability_domain"),
6266 "Expected strict error, got: {}",
6267 err
6268 );
6269 }
6270
/// Strict mode must still accept in-domain input: 0.5 passes the check and is
/// stored as the accumulator value.
#[test]
fn test_monotonic_agg_strict_accepts_valid() {
    let mut agg = MonotonicAggState::new(make_nor_binding());
    let batch = make_f64_batch(&["a"], &[0.5]);
    // `expect` (rather than `assert!(result.is_ok())`) puts the actual error in the
    // failure message if strict validation wrongly rejects an in-domain value.
    agg.update(&[0], &[batch], true, SemiringKind::AddMultProb)
        .expect("strict update with an in-domain value (0.5) should succeed");
    let val = agg.get_accumulator(&acc_key("a")).unwrap();
    assert!((val - 0.5).abs() < 1e-10, "expected 0.5, got {}", val);
}
6280
6281 fn make_vid_prob_batch(vids: &[u64], probs: &[f64]) -> RecordBatch {
6284 use arrow_array::UInt64Array;
6285 let schema = Arc::new(Schema::new(vec![
6286 Field::new("vid", DataType::UInt64, true),
6287 Field::new("prob", DataType::Float64, true),
6288 ]));
6289 RecordBatch::try_new(
6290 schema,
6291 vec![
6292 Arc::new(UInt64Array::from(vids.to_vec())),
6293 Arc::new(Float64Array::from(probs.to_vec())),
6294 ],
6295 )
6296 .unwrap()
6297 }
6298
/// Body vid 1 has a single negative match (p = 0.7), so its complement is 0.3.
/// Vid 2 has no negative match, so its complement is 1.0.
#[test]
fn test_prob_complement_basic() {
    let body = make_vid_prob_batch(&[1, 2], &[0.9, 0.8]);
    let negatives = [make_vid_prob_batch(&[1], &[0.7])];
    let keys = vec![("vid".to_string(), "vid".to_string())];
    let out =
        apply_prob_complement_composite(vec![body], &negatives, &keys, "prob", "__complement_0")
            .unwrap();
    assert_eq!(out.len(), 1);
    let complement = out[0]
        .column_by_name("__complement_0")
        .and_then(|col| col.as_any().downcast_ref::<Float64Array>())
        .unwrap();
    // Row 0 is vid 1 (matched once).
    assert!(
        (complement.value(0) - 0.3).abs() < 1e-10,
        "expected 0.3, got {}",
        complement.value(0)
    );
    // Row 1 is vid 2 (unmatched).
    assert!(
        (complement.value(1) - 1.0).abs() < 1e-10,
        "expected 1.0, got {}",
        complement.value(1)
    );
}
6334
/// Two negative rows share vid 1 (p = 0.3 and p = 0.5). They are combined with
/// noisy-OR, 1 - (1 - 0.3)(1 - 0.5) = 0.65, so the complement is 0.35.
/// The single body row must not be fanned out into one row per negative match.
#[test]
fn test_prob_complement_noisy_or_duplicates() {
    let body = make_vid_prob_batch(&[1], &[0.9]);
    let neg = make_vid_prob_batch(&[1, 1], &[0.3, 0.5]);
    let join_cols = vec![("vid".to_string(), "vid".to_string())];
    let result = apply_prob_complement_composite(
        vec![body],
        &[neg],
        &join_cols,
        "prob",
        "__complement_0",
    )
    .unwrap();
    assert_eq!(result.len(), 1);
    let batch = &result[0];
    // Checking only row 0 would miss a join that emits extra body rows for the
    // duplicate negative keys, so pin the row count as well.
    assert_eq!(
        batch.num_rows(),
        1,
        "duplicate negatives must not duplicate body rows"
    );
    let complement = batch
        .column_by_name("__complement_0")
        .unwrap()
        .as_any()
        .downcast_ref::<Float64Array>()
        .unwrap();
    assert!(
        (complement.value(0) - 0.35).abs() < 1e-10,
        "expected 0.35, got {}",
        complement.value(0)
    );
}
6364
/// With no negative batches, every body row gets the neutral complement 1.0.
#[test]
fn test_prob_complement_empty_neg() {
    let keys = vec![("vid".to_string(), "vid".to_string())];
    let body = make_vid_prob_batch(&[1, 2], &[0.5, 0.6]);
    let out = apply_prob_complement_composite(vec![body], &[], &keys, "prob", "__complement_0")
        .unwrap();
    let complement = out[0]
        .column_by_name("__complement_0")
        .and_then(|col| col.as_any().downcast_ref::<Float64Array>())
        .unwrap();
    (0..2).for_each(|row| {
        assert!(
            (complement.value(row) - 1.0).abs() < 1e-10,
            "row {}: expected 1.0, got {}",
            row,
            complement.value(row)
        )
    });
}
6389
/// The anti-join drops every body row whose vid appears on the negative side.
/// Vid 2 is removed, and vids 1 and 3 survive in their original order.
#[test]
fn test_anti_join_basic() {
    use arrow_array::UInt64Array;
    let keys = vec![("vid".to_string(), "vid".to_string())];
    let body = make_vid_prob_batch(&[1, 2, 3], &[0.5, 0.6, 0.7]);
    let excluded = make_vid_prob_batch(&[2], &[0.0]);
    let out = apply_anti_join_composite(vec![body], &[excluded], &keys).unwrap();
    assert_eq!(out.len(), 1);
    let kept = &out[0];
    assert_eq!(kept.num_rows(), 2);
    let vids = kept
        .column_by_name("vid")
        .and_then(|col| col.as_any().downcast_ref::<UInt64Array>())
        .unwrap();
    assert_eq!([vids.value(0), vids.value(1)], [1, 3]);
}
6410
/// With no negative batches the anti-join keeps everything: all three body rows
/// come back in a single batch.
#[test]
fn test_anti_join_empty_neg() {
    let keys = vec![("vid".to_string(), "vid".to_string())];
    let out = apply_anti_join_composite(
        vec![make_vid_prob_batch(&[1, 2, 3], &[0.5, 0.6, 0.7])],
        &[],
        &keys,
    )
    .unwrap();
    assert_eq!(out.len(), 1);
    assert_eq!(out[0].num_rows(), 3);
}
6420
/// When every body vid has a negative match, no rows survive the anti-join.
#[test]
fn test_anti_join_all_excluded() {
    let keys = vec![("vid".to_string(), "vid".to_string())];
    let body = make_vid_prob_batch(&[1, 2], &[0.5, 0.6]);
    let neg = make_vid_prob_batch(&[1, 2], &[0.0, 0.0]);
    let out = apply_anti_join_composite(vec![body], &[neg], &keys).unwrap();
    // Sum rows across all returned batches, so the check does not depend on
    // whether zero-row batches are emitted.
    let surviving: usize = out.iter().map(RecordBatch::num_rows).sum();
    assert_eq!(surviving, 0);
}
6431
/// Multiplying one complement column (0.5) into prob 0.8 yields 0.4, and the
/// `__complement_0` helper column is removed from the output.
#[test]
fn test_multiply_prob_single_complement() {
    let body = make_vid_prob_batch(&[1], &[0.8]);
    // Input = body columns plus a trailing nullable Float64 complement column.
    let fields: Vec<Arc<Field>> = body
        .schema()
        .fields()
        .iter()
        .cloned()
        .chain(std::iter::once(Arc::new(Field::new(
            "__complement_0",
            DataType::Float64,
            true,
        ))))
        .collect();
    let columns: Vec<arrow_array::ArrayRef> = body
        .columns()
        .iter()
        .cloned()
        .chain(std::iter::once(
            Arc::new(Float64Array::from(vec![0.5])) as arrow_array::ArrayRef
        ))
        .collect();
    let batch = RecordBatch::try_new(Arc::new(Schema::new(fields)), columns).unwrap();

    let out = multiply_prob_factors(vec![batch], Some("prob"), &["__complement_0".to_string()])
        .unwrap();
    assert_eq!(out.len(), 1);
    let product = &out[0];
    assert!(product.column_by_name("__complement_0").is_none());
    let prob = product
        .column_by_name("prob")
        .and_then(|col| col.as_any().downcast_ref::<Float64Array>())
        .unwrap();
    assert!(
        (prob.value(0) - 0.4).abs() < 1e-10,
        "expected 0.4, got {}",
        prob.value(0)
    );
}
6468
/// All complement columns are multiplied in: 0.8 * 0.5 * 0.6 = 0.24. Both helper
/// columns are removed from the output.
#[test]
fn test_multiply_prob_multiple_complements() {
    let body = make_vid_prob_batch(&[1], &[0.8]);
    let mut columns: Vec<arrow_array::ArrayRef> = body.columns().to_vec();
    columns.extend([
        Arc::new(Float64Array::from(vec![0.5])) as arrow_array::ArrayRef,
        Arc::new(Float64Array::from(vec![0.6])) as arrow_array::ArrayRef,
    ]);
    let mut fields: Vec<Arc<Field>> = body.schema().fields().iter().cloned().collect();
    fields.extend(
        ["__c1", "__c2"].map(|name| Arc::new(Field::new(name, DataType::Float64, true))),
    );
    let batch = RecordBatch::try_new(Arc::new(Schema::new(fields)), columns).unwrap();

    let out = multiply_prob_factors(
        vec![batch],
        Some("prob"),
        &["__c1".to_string(), "__c2".to_string()],
    )
    .unwrap();
    let product = &out[0];
    for dropped in ["__c1", "__c2"] {
        assert!(product.column_by_name(dropped).is_none());
    }
    let prob = product
        .column_by_name("prob")
        .and_then(|col| col.as_any().downcast_ref::<Float64Array>())
        .unwrap();
    assert!(
        (prob.value(0) - 0.24).abs() < 1e-10,
        "expected 0.24, got {}",
        prob.value(0)
    );
}
6505
/// With no probability column (`None`), complement columns are only stripped.
/// The `__c1` helper disappears and the `vid` column is kept with its data intact.
#[test]
fn test_multiply_prob_no_prob_column() {
    use arrow_array::UInt64Array;
    let schema = Arc::new(Schema::new(vec![
        Field::new("vid", DataType::UInt64, true),
        Field::new("__c1", DataType::Float64, true),
    ]));
    let batch = RecordBatch::try_new(
        schema,
        vec![
            Arc::new(UInt64Array::from(vec![1u64])),
            Arc::new(Float64Array::from(vec![0.7])),
        ],
    )
    .unwrap();

    let result = multiply_prob_factors(vec![batch], None, &["__c1".to_string()]).unwrap();
    assert_eq!(result.len(), 1);
    let out = &result[0];
    assert!(out.column_by_name("__c1").is_none());
    assert_eq!(out.num_columns(), 1);
    // The column count alone does not say *which* column survived. Pin that it is
    // the key column and that its row data passed through unchanged.
    assert_eq!(out.num_rows(), 1);
    let vids = out
        .column_by_name("vid")
        .expect("vid column should be preserved")
        .as_any()
        .downcast_ref::<UInt64Array>()
        .unwrap();
    assert_eq!(vids.value(0), 1);
}
6530}