1use crate::query::df_graph::GraphExecutionContext;
10use crate::query::df_graph::common::{
11 ScalarKey, arrow_err, collect_all_partitions, compute_plan_properties, execute_subplan,
12 extract_scalar_key,
13};
14use crate::query::df_graph::locy_best_by::{BestByExec, SortCriterion};
15use crate::query::df_graph::locy_errors::LocyRuntimeError;
16use crate::query::df_graph::locy_explain::{
17 ProofTerm, ProvenanceAnnotation, ProvenanceStore, compute_proof_probability,
18};
19use crate::query::df_graph::locy_fold::{FoldBinding, FoldExec};
20use crate::query::df_graph::locy_priority::PriorityExec;
21use crate::query::df_graph::locy_program::interruption;
22use crate::query::planner::LogicalPlan;
23use arrow_array::RecordBatch;
24use arrow_row::{RowConverter, SortField};
25use arrow_schema::SchemaRef;
26use datafusion::common::JoinType;
27use datafusion::common::Result as DFResult;
28use datafusion::execution::{RecordBatchStream, SendableRecordBatchStream, TaskContext};
29use datafusion::physical_plan::joins::{HashJoinExec, PartitionMode};
30use datafusion::physical_plan::memory::MemoryStream;
31use datafusion::physical_plan::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet};
32use datafusion::physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties};
33use futures::Stream;
34use parking_lot::RwLock;
35use std::any::Any;
36use std::collections::{HashMap, HashSet};
37use std::fmt;
38use std::pin::Pin;
39use std::sync::{Arc, RwLock as StdRwLock};
40use std::task::{Context, Poll};
41use std::time::{Duration, Instant};
42use uni_common::Value;
43use uni_common::core::schema::Schema as UniSchema;
44use uni_cypher::ast::Expr;
45use uni_locy::{
46 ClassifierRegistry, ModelInvocation, ModelInvocationCache, RuntimeWarning, RuntimeWarningCode,
47 SemiringKind,
48};
49use uni_store::storage::manager::StorageManager;
50
/// One derived-relation scan slot held by a [`DerivedScanRegistry`].
///
/// Entries are located either by `scan_index` ([`DerivedScanRegistry::get`]) or by
/// `rule_name` ([`DerivedScanRegistry::entries_for_rule`]).
#[derive(Debug)]
pub struct DerivedScanEntry {
    /// Identifier matched by [`DerivedScanRegistry::get`].
    pub scan_index: usize,
    /// Name of the rule this scan reads from; matched by
    /// [`DerivedScanRegistry::entries_for_rule`].
    pub rule_name: String,
    /// NOTE(review): presumably marks a scan that references the rule being
    /// evaluated (recursive self-reference) — confirm against the planner.
    pub is_self_ref: bool,
    /// Shared, lock-protected batches exposed through this scan; replaced wholesale
    /// by [`DerivedScanRegistry::write_data`].
    pub data: Arc<RwLock<Vec<RecordBatch>>>,
    /// Schema of the batches in `data` (assumed; not enforced here).
    pub schema: SchemaRef,
}
73
/// Collection of [`DerivedScanEntry`] slots, searched linearly by scan index or
/// rule name.
#[derive(Debug, Default)]
pub struct DerivedScanRegistry {
    // Stored in insertion order; lookups scan the whole vector.
    entries: Vec<DerivedScanEntry>,
}
84
85impl DerivedScanRegistry {
86 pub fn new() -> Self {
88 Self::default()
89 }
90
91 pub fn add(&mut self, entry: DerivedScanEntry) {
93 self.entries.push(entry);
94 }
95
96 pub fn get(&self, scan_index: usize) -> Option<&DerivedScanEntry> {
98 self.entries.iter().find(|e| e.scan_index == scan_index)
99 }
100
101 pub fn write_data(&self, scan_index: usize, batches: Vec<RecordBatch>) {
103 if let Some(entry) = self.get(scan_index) {
104 let mut guard = entry.data.write();
105 *guard = batches;
106 }
107 }
108
109 pub fn entries_for_rule(&self, rule_name: &str) -> Vec<&DerivedScanEntry> {
111 self.entries
112 .iter()
113 .filter(|e| e.rule_name == rule_name)
114 .collect()
115 }
116}
117
/// Describes one FOLD output tracked incrementally by [`MonotonicAggState`].
#[derive(Debug, Clone)]
pub struct MonotonicFoldBinding {
    /// Output name of the fold; forms part of each accumulator key.
    pub fold_name: String,
    /// Aggregate plugin supplying the initial value and the per-row update step.
    pub aggregate: std::sync::Arc<dyn uni_plugin::traits::locy::LocyAggregate>,
    /// Positional input column, used when `input_col_name` is absent or cannot be
    /// resolved in a batch's schema.
    pub input_col_index: usize,
    /// Preferred input column, looked up by name in each batch's schema.
    pub input_col_name: Option<String>,
}
136
/// Running FOLD accumulators used to decide whether aggregate values have stabilised
/// across fixpoint iterations.
#[derive(Debug)]
pub struct MonotonicAggState {
    // Current value per (group key, fold name).
    accumulators: HashMap<(Vec<ScalarKey>, String), f64>,
    // Copy of `accumulators` taken by `snapshot()`; compared by `is_stable()`.
    prev_snapshot: HashMap<(Vec<ScalarKey>, String), f64>,
    // Folds being tracked.
    bindings: Vec<MonotonicFoldBinding>,
}
151
152impl MonotonicAggState {
153 pub fn new(bindings: Vec<MonotonicFoldBinding>) -> Self {
155 Self {
156 accumulators: HashMap::new(),
157 prev_snapshot: HashMap::new(),
158 bindings,
159 }
160 }
161
162 pub fn update(
188 &mut self,
189 key_indices: &[usize],
190 delta_batches: &[RecordBatch],
191 strict: bool,
192 semiring_kind: SemiringKind,
193 ) -> DFResult<bool> {
194 let mut changed = false;
195 for batch in delta_batches {
196 for row_idx in 0..batch.num_rows() {
197 let group_key = extract_scalar_key(batch, key_indices, row_idx);
198 for binding in &self.bindings {
199 let idx = binding
200 .input_col_name
201 .as_ref()
202 .and_then(|name| batch.schema().index_of(name).ok())
203 .unwrap_or(binding.input_col_index);
204 if idx >= batch.num_columns() {
205 continue;
206 }
207 let col = batch.column(idx);
208 let val = extract_f64(col.as_ref(), row_idx);
209 if let Some(val) = val {
210 let map_key = (group_key.clone(), binding.fold_name.clone());
211 let initial = binding.aggregate.initial_accum_f64().unwrap_or(0.0);
212 let entry = self.accumulators.entry(map_key).or_insert(initial);
213 let old = *entry;
214 if matches!(semiring_kind, SemiringKind::MaxMinProb)
225 && binding.aggregate.is_probability_aggregate()
226 {
227 use uni_locy::LocySemiring;
228 let sr = uni_locy::MaxMinProb;
229 let is_nor = binding.aggregate.is_noisy_or();
230 let label = if is_nor { "MNOR" } else { "MPROD" };
231 if strict && !(0.0..=1.0).contains(&val) {
232 return Err(datafusion::error::DataFusionError::Execution(
233 format!(
234 "strict_probability_domain: {label} input {val} is outside [0, 1]"
235 ),
236 ));
237 }
238 if !strict && !(0.0..=1.0).contains(&val) {
239 tracing::warn!(
240 "{label} input {val} outside [0,1], clamped to {}",
241 val.clamp(0.0, 1.0)
242 );
243 }
244 let p = val.clamp(0.0, 1.0);
245 *entry = if is_nor {
247 sr.plus(entry, &p)
248 } else {
249 sr.times(entry, &p)
250 };
251 if (*entry - old).abs() > f64::EPSILON {
252 changed = true;
253 }
254 continue;
255 }
256 match binding.aggregate.update_step(*entry, val, strict) {
257 Ok(new_val) => {
258 *entry = new_val;
259 if (*entry - old).abs() > f64::EPSILON {
260 changed = true;
261 }
262 }
263 Err(e) if e.code == uni_plugin::FnError::CODE_UNKNOWN_FUNCTION => {
264 }
268 Err(e) => {
269 return Err(datafusion::error::DataFusionError::Execution(
272 e.message,
273 ));
274 }
275 }
276 }
277 }
278 }
279 }
280 Ok(changed)
281 }
282
283 pub fn snapshot(&mut self) {
285 self.prev_snapshot = self.accumulators.clone();
286 }
287
288 pub fn is_stable(&self) -> bool {
290 if self.accumulators.len() != self.prev_snapshot.len() {
291 return false;
292 }
293 for (key, val) in &self.accumulators {
294 match self.prev_snapshot.get(key) {
295 Some(prev) if (*val - *prev).abs() <= f64::EPSILON => {}
296 _ => return false,
297 }
298 }
299 true
300 }
301
302 #[cfg(test)]
304 pub(crate) fn get_accumulator(&self, key: &(Vec<ScalarKey>, String)) -> Option<f64> {
305 self.accumulators.get(key).copied()
306 }
307}
308
309fn extract_f64(col: &dyn arrow_array::Array, row_idx: usize) -> Option<f64> {
311 if col.is_null(row_idx) {
312 return None;
313 }
314 if let Some(arr) = col.as_any().downcast_ref::<arrow_array::Float64Array>() {
315 Some(arr.value(row_idx))
316 } else {
317 col.as_any()
318 .downcast_ref::<arrow_array::Int64Array>()
319 .map(|arr| arr.value(row_idx) as f64)
320 }
321}
322
/// Incremental all-column row deduplicator.
///
/// Each row is encoded with an Arrow [`RowConverter`]. The encoded bytes of every
/// row accepted so far are remembered.
struct RowDedupState {
    // Encodes rows of the fixpoint schema into comparable byte strings.
    converter: RowConverter,
    // Encoded bytes of every row already admitted (or ingested from existing facts).
    seen: HashSet<Box<[u8]>>,
}
336
337impl RowDedupState {
338 fn try_new(schema: &SchemaRef) -> Option<Self> {
343 let fields: Vec<SortField> = schema
344 .fields()
345 .iter()
346 .map(|f| SortField::new(f.data_type().clone()))
347 .collect();
348 match RowConverter::new(fields) {
349 Ok(converter) => Some(Self {
350 converter,
351 seen: HashSet::new(),
352 }),
353 Err(e) => {
354 tracing::warn!(
355 "RowDedupState: RowConverter unsupported for schema, falling back to legacy dedup: {}",
356 e
357 );
358 None
359 }
360 }
361 }
362
363 fn ingest_existing(&mut self, facts: &[RecordBatch], _schema: &SchemaRef) {
368 self.seen.clear();
369 for batch in facts {
370 if batch.num_rows() == 0 {
371 continue;
372 }
373 let arrays: Vec<_> = batch.columns().to_vec();
374 if let Ok(rows) = self.converter.convert_columns(&arrays) {
375 for row_idx in 0..batch.num_rows() {
376 let row_bytes: Box<[u8]> = rows.row(row_idx).data().into();
377 self.seen.insert(row_bytes);
378 }
379 }
380 }
381 }
382
383 fn compute_delta(
389 &mut self,
390 candidates: &[RecordBatch],
391 schema: &SchemaRef,
392 ) -> DFResult<Vec<RecordBatch>> {
393 let mut delta_batches = Vec::new();
394 for batch in candidates {
395 if batch.num_rows() == 0 {
396 continue;
397 }
398
399 let arrays: Vec<_> = batch.columns().to_vec();
401 let rows = self.converter.convert_columns(&arrays).map_err(arrow_err)?;
402
403 let mut keep = Vec::with_capacity(batch.num_rows());
405 for row_idx in 0..batch.num_rows() {
406 let row_bytes: Box<[u8]> = rows.row(row_idx).data().into();
407 keep.push(self.seen.insert(row_bytes));
408 }
409
410 let keep_mask = arrow_array::BooleanArray::from(keep);
411 let new_cols = batch
412 .columns()
413 .iter()
414 .map(|col| {
415 arrow::compute::filter(col.as_ref(), &keep_mask).map_err(|e| {
416 datafusion::error::DataFusionError::ArrowError(Box::new(e), None)
417 })
418 })
419 .collect::<DFResult<Vec<_>>>()?;
420
421 if new_cols.first().is_some_and(|c| !c.is_empty()) {
422 let filtered = RecordBatch::try_new(Arc::clone(schema), new_cols).map_err(|e| {
423 datafusion::error::DataFusionError::ArrowError(Box::new(e), None)
424 })?;
425 delta_batches.push(filtered);
426 }
427 }
428 Ok(delta_batches)
429 }
430}
431
/// Per-rule accumulation state for the semi-naive fixpoint loop.
pub struct FixpointState {
    // Rule name, used in logs and memory-limit errors.
    rule_name: String,
    // Every fact accepted so far.
    facts: Vec<RecordBatch>,
    // Facts newly accepted by the most recent merge.
    delta: Vec<RecordBatch>,
    // Schema of stored facts; may be replaced by `reconcile_schema`.
    schema: SchemaRef,
    // Grouping-key column positions (used for aggregates and BEST BY).
    key_column_indices: Vec<usize>,
    // Names of the key columns, used to re-resolve positions after schema changes.
    key_column_names: Vec<String>,
    // All column positions `0..num_cols`, used by the legacy dedup path.
    all_column_indices: Vec<usize>,
    // Approximate buffer bytes held by `facts`.
    facts_bytes: usize,
    // Limit on `facts_bytes`; exceeding it is a runtime error.
    max_derived_bytes: usize,
    // Incremental FOLD tracking used for convergence checks, when the rule folds.
    monotonic_agg: Option<MonotonicAggState>,
    // Row-encoded dedup; `None` means the legacy ScalarKey dedup is used.
    row_dedup: Option<RowDedupState>,
    // Forwarded to the monotonic aggregate update.
    strict_probability_domain: bool,
    // Forwarded to the monotonic aggregate update.
    semiring_kind: SemiringKind,
}
464
465impl FixpointState {
466 pub fn new(
471 rule_name: String,
472 schema: SchemaRef,
473 key_column_indices: Vec<usize>,
474 max_derived_bytes: usize,
475 monotonic_agg: Option<MonotonicAggState>,
476 strict_probability_domain: bool,
477 ) -> Self {
478 Self::new_with_semiring(
479 rule_name,
480 schema,
481 key_column_indices,
482 max_derived_bytes,
483 monotonic_agg,
484 strict_probability_domain,
485 SemiringKind::AddMultProb,
486 )
487 }
488
489 pub fn new_with_semiring(
490 rule_name: String,
491 schema: SchemaRef,
492 key_column_indices: Vec<usize>,
493 max_derived_bytes: usize,
494 monotonic_agg: Option<MonotonicAggState>,
495 strict_probability_domain: bool,
496 semiring_kind: SemiringKind,
497 ) -> Self {
498 let num_cols = schema.fields().len();
499 let row_dedup = RowDedupState::try_new(&schema);
500 let key_column_names: Vec<String> = key_column_indices
501 .iter()
502 .filter_map(|&i| schema.fields().get(i).map(|f| f.name().clone()))
503 .collect();
504 Self {
505 rule_name,
506 facts: Vec::new(),
507 delta: Vec::new(),
508 schema,
509 key_column_indices,
510 key_column_names,
511 all_column_indices: (0..num_cols).collect(),
512 facts_bytes: 0,
513 max_derived_bytes,
514 monotonic_agg,
515 row_dedup,
516 strict_probability_domain,
517 semiring_kind,
518 }
519 }
520
521 fn reconcile_schema(&mut self, actual_schema: &SchemaRef) {
528 if self.schema.fields() != actual_schema.fields() {
529 tracing::debug!(
530 rule = %self.rule_name,
531 "Reconciling fixpoint schema from physical plan output",
532 );
533 self.schema = Arc::clone(actual_schema);
534 self.row_dedup = RowDedupState::try_new(&self.schema);
535 let new_indices: Vec<usize> = self
539 .key_column_names
540 .iter()
541 .filter_map(|name| actual_schema.index_of(name).ok())
542 .collect();
543 if new_indices.len() == self.key_column_names.len() {
544 self.key_column_indices = new_indices;
545 }
546 let num_cols = actual_schema.fields().len();
548 self.all_column_indices = (0..num_cols).collect();
549 }
550 }
551
552 pub async fn merge_delta(
556 &mut self,
557 candidates: Vec<RecordBatch>,
558 task_ctx: Option<Arc<TaskContext>>,
559 ) -> DFResult<bool> {
560 if candidates.is_empty() || candidates.iter().all(|b| b.num_rows() == 0) {
561 self.delta.clear();
562 return Ok(false);
563 }
564
565 if let Some(first) = candidates.iter().find(|b| b.num_rows() > 0) {
569 self.reconcile_schema(&first.schema());
570 }
571
572 let candidates = round_float_columns(&candidates);
574
575 let delta = self.compute_delta(&candidates, task_ctx.as_ref()).await?;
577
578 if delta.is_empty() || delta.iter().all(|b| b.num_rows() == 0) {
579 self.delta.clear();
580 if let Some(ref mut agg) = self.monotonic_agg {
582 agg.snapshot();
583 }
584 return Ok(false);
585 }
586
587 let delta_bytes: usize = delta.iter().map(batch_byte_size).sum();
589 if self.facts_bytes + delta_bytes > self.max_derived_bytes {
590 return Err(datafusion::error::DataFusionError::Execution(
591 LocyRuntimeError::MemoryLimitExceeded {
592 rule: self.rule_name.clone(),
593 bytes: self.facts_bytes + delta_bytes,
594 limit: self.max_derived_bytes,
595 }
596 .to_string(),
597 ));
598 }
599
600 if let Some(ref mut agg) = self.monotonic_agg {
602 agg.snapshot();
603 agg.update(
604 &self.key_column_indices,
605 &delta,
606 self.strict_probability_domain,
607 self.semiring_kind,
608 )?;
609 }
610
611 self.facts_bytes += delta_bytes;
613 self.facts.extend(delta.iter().cloned());
614 self.delta = delta;
615
616 Ok(true)
617 }
618
619 async fn compute_delta(
626 &mut self,
627 candidates: &[RecordBatch],
628 task_ctx: Option<&Arc<TaskContext>>,
629 ) -> DFResult<Vec<RecordBatch>> {
630 let total_existing: usize = self.facts.iter().map(|b| b.num_rows()).sum();
631 if total_existing >= DEDUP_ANTI_JOIN_THRESHOLD
632 && let Some(ctx) = task_ctx
633 {
634 return arrow_left_anti_dedup(candidates.to_vec(), &self.facts, &self.schema, ctx)
635 .await;
636 }
637 if let Some(ref mut rd) = self.row_dedup {
638 rd.compute_delta(candidates, &self.schema)
639 } else {
640 self.compute_delta_legacy(candidates)
641 }
642 }
643
644 fn compute_delta_legacy(&self, candidates: &[RecordBatch]) -> DFResult<Vec<RecordBatch>> {
648 let mut existing: HashSet<Vec<ScalarKey>> = HashSet::new();
650 for batch in &self.facts {
651 for row_idx in 0..batch.num_rows() {
652 let key = extract_scalar_key(batch, &self.all_column_indices, row_idx);
653 existing.insert(key);
654 }
655 }
656
657 let mut delta_batches = Vec::new();
658 for batch in candidates {
659 if batch.num_rows() == 0 {
660 continue;
661 }
662 let mut keep = Vec::with_capacity(batch.num_rows());
664 for row_idx in 0..batch.num_rows() {
665 let key = extract_scalar_key(batch, &self.all_column_indices, row_idx);
666 keep.push(!existing.contains(&key));
667 }
668
669 for (row_idx, kept) in keep.iter_mut().enumerate() {
671 if *kept {
672 let key = extract_scalar_key(batch, &self.all_column_indices, row_idx);
673 if !existing.insert(key) {
674 *kept = false;
675 }
676 }
677 }
678
679 let keep_mask = arrow_array::BooleanArray::from(keep);
680 let new_rows = batch
681 .columns()
682 .iter()
683 .map(|col| {
684 arrow::compute::filter(col.as_ref(), &keep_mask).map_err(|e| {
685 datafusion::error::DataFusionError::ArrowError(Box::new(e), None)
686 })
687 })
688 .collect::<DFResult<Vec<_>>>()?;
689
690 if new_rows.first().is_some_and(|c| !c.is_empty()) {
691 let filtered =
692 RecordBatch::try_new(Arc::clone(&self.schema), new_rows).map_err(|e| {
693 datafusion::error::DataFusionError::ArrowError(Box::new(e), None)
694 })?;
695 delta_batches.push(filtered);
696 }
697 }
698
699 Ok(delta_batches)
700 }
701
702 pub fn is_converged(&self) -> bool {
704 let delta_empty = self.delta.is_empty() || self.delta.iter().all(|b| b.num_rows() == 0);
705 let agg_stable = self.monotonic_agg.as_ref().is_none_or(|a| a.is_stable());
706 delta_empty && agg_stable
707 }
708
709 pub fn all_facts(&self) -> &[RecordBatch] {
711 &self.facts
712 }
713
714 pub fn all_delta(&self) -> &[RecordBatch] {
716 &self.delta
717 }
718
719 pub fn into_facts(self) -> Vec<RecordBatch> {
721 self.facts
722 }
723
724 pub fn merge_best_by(
735 &mut self,
736 candidates: Vec<RecordBatch>,
737 sort_criteria: &[SortCriterion],
738 ) -> DFResult<bool> {
739 if candidates.is_empty() || candidates.iter().all(|b| b.num_rows() == 0) {
740 self.delta.clear();
741 return Ok(false);
742 }
743
744 if let Some(first) = candidates.iter().find(|b| b.num_rows() > 0) {
746 self.reconcile_schema(&first.schema());
747 }
748
749 let candidates = round_float_columns(&candidates);
751
752 let old_best: HashMap<Vec<ScalarKey>, Vec<ScalarKey>> =
754 self.build_key_criteria_map(sort_criteria);
755
756 let mut all_batches = self.facts.clone();
758 all_batches.extend(candidates);
759 let all_batches: Vec<_> = all_batches
760 .into_iter()
761 .filter(|b| b.num_rows() > 0)
762 .collect();
763 if all_batches.is_empty() {
764 self.delta.clear();
765 return Ok(false);
766 }
767
768 let combined = arrow::compute::concat_batches(&self.schema, &all_batches)
769 .map_err(|e| datafusion::error::DataFusionError::ArrowError(Box::new(e), None))?;
770
771 if combined.num_rows() == 0 {
772 self.delta.clear();
773 return Ok(false);
774 }
775
776 let mut sort_columns = Vec::new();
779 for &ki in &self.key_column_indices {
780 if ki >= combined.num_columns() {
781 continue;
782 }
783 sort_columns.push(arrow::compute::SortColumn {
784 values: Arc::clone(combined.column(ki)),
785 options: Some(arrow::compute::SortOptions {
786 descending: false,
787 nulls_first: false,
788 }),
789 });
790 }
791 for criterion in sort_criteria {
792 if criterion.col_index >= combined.num_columns() {
793 continue;
794 }
795 sort_columns.push(arrow::compute::SortColumn {
796 values: Arc::clone(combined.column(criterion.col_index)),
797 options: Some(arrow::compute::SortOptions {
798 descending: !criterion.ascending,
799 nulls_first: criterion.nulls_first,
800 }),
801 });
802 }
803
804 let sorted_indices =
805 arrow::compute::lexsort_to_indices(&sort_columns, None).map_err(arrow_err)?;
806 let sorted_columns: Vec<_> = combined
807 .columns()
808 .iter()
809 .map(|col| arrow::compute::take(col.as_ref(), &sorted_indices, None))
810 .collect::<Result<Vec<_>, _>>()
811 .map_err(|e| datafusion::error::DataFusionError::ArrowError(Box::new(e), None))?;
812 let sorted = RecordBatch::try_new(Arc::clone(&self.schema), sorted_columns)
813 .map_err(|e| datafusion::error::DataFusionError::ArrowError(Box::new(e), None))?;
814
815 let mut keep_indices: Vec<u32> = Vec::new();
817 let mut prev_key: Option<Vec<ScalarKey>> = None;
818 for row_idx in 0..sorted.num_rows() {
819 let key = extract_scalar_key(&sorted, &self.key_column_indices, row_idx);
820 let is_new_group = match &prev_key {
821 None => true,
822 Some(prev) => *prev != key,
823 };
824 if is_new_group {
825 keep_indices.push(row_idx as u32);
826 prev_key = Some(key);
827 }
828 }
829
830 let keep_array = arrow_array::UInt32Array::from(keep_indices);
831 let output_columns: Vec<_> = sorted
832 .columns()
833 .iter()
834 .map(|col| arrow::compute::take(col.as_ref(), &keep_array, None))
835 .collect::<Result<Vec<_>, _>>()
836 .map_err(|e| datafusion::error::DataFusionError::ArrowError(Box::new(e), None))?;
837 let pruned = RecordBatch::try_new(Arc::clone(&self.schema), output_columns)
838 .map_err(|e| datafusion::error::DataFusionError::ArrowError(Box::new(e), None))?;
839
840 let new_best: HashMap<Vec<ScalarKey>, Vec<ScalarKey>> = {
842 let mut map = HashMap::new();
843 for row_idx in 0..pruned.num_rows() {
844 let key = extract_scalar_key(&pruned, &self.key_column_indices, row_idx);
845 let criteria: Vec<ScalarKey> = sort_criteria
846 .iter()
847 .flat_map(|c| extract_scalar_key(&pruned, &[c.col_index], row_idx))
848 .collect();
849 map.insert(key, criteria);
850 }
851 map
852 };
853 let changed = old_best != new_best;
854
855 tracing::debug!(
856 rule = %self.rule_name,
857 old_keys = old_best.len(),
858 new_keys = new_best.len(),
859 changed = changed,
860 "BEST BY merge"
861 );
862
863 self.facts_bytes = batch_byte_size(&pruned);
865 self.facts = vec![pruned];
866 if changed {
867 self.delta = self.facts.clone();
870 } else {
871 self.delta.clear();
872 }
873
874 self.row_dedup = RowDedupState::try_new(&self.schema);
876 if let Some(ref mut rd) = self.row_dedup {
877 rd.ingest_existing(&self.facts, &self.schema);
878 }
879
880 Ok(changed)
881 }
882
883 fn build_key_criteria_map(
885 &self,
886 sort_criteria: &[SortCriterion],
887 ) -> HashMap<Vec<ScalarKey>, Vec<ScalarKey>> {
888 let mut map = HashMap::new();
889 for batch in &self.facts {
890 for row_idx in 0..batch.num_rows() {
891 let key = extract_scalar_key(batch, &self.key_column_indices, row_idx);
892 let criteria: Vec<ScalarKey> = sort_criteria
893 .iter()
894 .flat_map(|c| extract_scalar_key(batch, &[c.col_index], row_idx))
895 .collect();
896 map.insert(key, criteria);
897 }
898 }
899 map
900 }
901}
902
903fn batch_byte_size(batch: &RecordBatch) -> usize {
905 batch
906 .columns()
907 .iter()
908 .map(|col| col.get_buffer_memory_size())
909 .sum()
910}
911
912fn round_float_columns(batches: &[RecordBatch]) -> Vec<RecordBatch> {
918 batches
919 .iter()
920 .map(|batch| {
921 let schema = batch.schema();
922 let has_float = schema
923 .fields()
924 .iter()
925 .any(|f| *f.data_type() == arrow_schema::DataType::Float64);
926 if !has_float {
927 return batch.clone();
928 }
929
930 let columns: Vec<arrow_array::ArrayRef> = batch
931 .columns()
932 .iter()
933 .enumerate()
934 .map(|(i, col)| {
935 if *schema.field(i).data_type() == arrow_schema::DataType::Float64 {
936 let arr = col
937 .as_any()
938 .downcast_ref::<arrow_array::Float64Array>()
939 .unwrap();
940 let rounded: arrow_array::Float64Array = arr
941 .iter()
942 .map(|v| v.map(|f| (f * 1e12).round() / 1e12))
943 .collect();
944 Arc::new(rounded) as arrow_array::ArrayRef
945 } else {
946 Arc::clone(col)
947 }
948 })
949 .collect();
950
951 RecordBatch::try_new(schema, columns).unwrap_or_else(|_| batch.clone())
952 })
953 .collect()
954}
955
/// Accumulated-fact row count at or above which `FixpointState::compute_delta`
/// switches to hash anti-join dedup (when a `TaskContext` is available).
const DEDUP_ANTI_JOIN_THRESHOLD: usize = 300;
966
967fn dedup_batches_all_columns(
982 batches: Vec<RecordBatch>,
983 schema: &SchemaRef,
984) -> DFResult<Vec<RecordBatch>> {
985 let fields: Vec<SortField> = schema
986 .fields()
987 .iter()
988 .map(|f| SortField::new(f.data_type().clone()))
989 .collect();
990 let Ok(converter) = RowConverter::new(fields) else {
994 return Ok(batches);
995 };
996 let mut seen: HashSet<Box<[u8]>> = HashSet::new();
997 let mut out = Vec::with_capacity(batches.len());
998 for batch in batches {
999 if batch.num_rows() == 0 {
1000 continue;
1001 }
1002 let rows = converter
1003 .convert_columns(batch.columns())
1004 .map_err(arrow_err)?;
1005 let mut keep = Vec::with_capacity(batch.num_rows());
1006 for row_idx in 0..batch.num_rows() {
1007 let row_bytes: Box<[u8]> = rows.row(row_idx).data().into();
1008 keep.push(seen.insert(row_bytes));
1009 }
1010 let keep_mask = arrow_array::BooleanArray::from(keep);
1011 let cols = batch
1012 .columns()
1013 .iter()
1014 .map(|c| arrow::compute::filter(c.as_ref(), &keep_mask).map_err(arrow_err))
1015 .collect::<DFResult<Vec<_>>>()?;
1016 if cols.first().is_some_and(|c| !c.is_empty()) {
1017 out.push(RecordBatch::try_new(Arc::clone(schema), cols).map_err(arrow_err)?);
1018 }
1019 }
1020 Ok(out)
1021}
1022
1023async fn arrow_left_anti_dedup(
1024 candidates: Vec<RecordBatch>,
1025 existing: &[RecordBatch],
1026 schema: &SchemaRef,
1027 task_ctx: &Arc<TaskContext>,
1028) -> DFResult<Vec<RecordBatch>> {
1029 if existing.is_empty() || existing.iter().all(|b| b.num_rows() == 0) {
1030 return dedup_batches_all_columns(candidates, schema);
1033 }
1034
1035 let left: Arc<dyn ExecutionPlan> = Arc::new(InMemoryExec::new(candidates, Arc::clone(schema)));
1036 let right: Arc<dyn ExecutionPlan> =
1037 Arc::new(InMemoryExec::new(existing.to_vec(), Arc::clone(schema)));
1038
1039 let on: Vec<(
1040 Arc<dyn datafusion::physical_plan::PhysicalExpr>,
1041 Arc<dyn datafusion::physical_plan::PhysicalExpr>,
1042 )> = schema
1043 .fields()
1044 .iter()
1045 .enumerate()
1046 .map(|(i, field)| {
1047 let l: Arc<dyn datafusion::physical_plan::PhysicalExpr> = Arc::new(
1048 datafusion::physical_plan::expressions::Column::new(field.name(), i),
1049 );
1050 let r: Arc<dyn datafusion::physical_plan::PhysicalExpr> = Arc::new(
1051 datafusion::physical_plan::expressions::Column::new(field.name(), i),
1052 );
1053 (l, r)
1054 })
1055 .collect();
1056
1057 if on.is_empty() {
1058 return Ok(vec![]);
1059 }
1060
1061 let join = HashJoinExec::try_new(
1062 left,
1063 right,
1064 on,
1065 None,
1066 &JoinType::LeftAnti,
1067 None,
1068 PartitionMode::CollectLeft,
1069 datafusion::common::NullEquality::NullEqualsNull,
1070 false,
1075 )?;
1076
1077 let join_arc: Arc<dyn ExecutionPlan> = Arc::new(join);
1078 let anti = collect_all_partitions(&join_arc, task_ctx.clone()).await?;
1081 dedup_batches_all_columns(anti, schema)
1082}
1083
/// A clause body's reference (`IS rule`) to another derived relation.
#[derive(Debug, Clone)]
pub struct IsRefBinding {
    /// Registry scan index holding the referenced rule's facts.
    pub derived_scan_index: usize,
    /// Referenced rule name; also used to name the `__prob_complement_<rule>` column.
    pub rule_name: String,
    /// NOTE(review): presumably true when the reference targets the enclosing
    /// rule itself — confirm against the planner.
    pub is_self_ref: bool,
    /// `IS NOT` reference: matching rows are removed (or probability-complemented).
    pub negated: bool,
    /// Column-name pairs used to match clause rows against the referenced facts.
    /// NOTE(review): pair orientation (clause, target) is assumed — confirm.
    pub anti_join_cols: Vec<(String, String)>,
    /// Whether the referenced rule carries a probability column.
    pub target_has_prob: bool,
    /// Name of the referenced rule's probability column, when known.
    pub target_prob_col: Option<String>,
    /// NOTE(review): not read in this section; presumably the join columns used
    /// when recording provenance.
    pub provenance_join_cols: Vec<(String, String)>,
}
1115
/// One clause of a recursive rule, as evaluated inside the fixpoint loop.
#[derive(Debug)]
pub struct FixpointClausePlan {
    /// Logical plan of the clause body, executed once per iteration.
    pub body_logical: LogicalPlan,
    /// References to derived relations made by the body.
    pub is_ref_bindings: Vec<IsRefBinding>,
    /// Clause priority, if declared (not read in this section).
    pub priority: Option<i64>,
    /// NOTE(review): presumably ALONG binding names — not read in this section.
    pub along_bindings: Vec<String>,
    /// NOTE(review): model invocations attached to the clause — not read in this
    /// section.
    pub model_invocations: Vec<ModelInvocation>,
}
1132
/// A recursive rule evaluated by `run_fixpoint_loop`.
#[derive(Debug)]
pub struct FixpointRulePlan {
    /// Rule name (used for warnings and iteration counts).
    pub name: String,
    /// Clauses whose outputs are unioned into the rule's candidates each iteration.
    pub clauses: Vec<FixpointClausePlan>,
    /// Initial schema of the rule's fixpoint state.
    pub yield_schema: SchemaRef,
    /// Grouping-key column positions within `yield_schema`.
    pub key_column_indices: Vec<usize>,
    /// Rule priority, if declared (not read directly in this section).
    pub priority: Option<i64>,
    /// NOTE(review): presumably set when the rule has FOLD clauses — not read here.
    pub has_fold: bool,
    /// FOLD outputs; a non-empty list enables monotonic aggregate tracking.
    pub fold_bindings: Vec<FoldBinding>,
    /// HAVING predicates (not read directly in this section).
    pub having: Vec<Expr>,
    /// Together with non-empty `best_by_criteria`, selects BEST BY merging.
    pub has_best_by: bool,
    /// Sort criteria for BEST BY selection.
    pub best_by_criteria: Vec<SortCriterion>,
    /// NOTE(review): presumably marks PRIORITY rules — not read in this section.
    pub has_priority: bool,
    /// NOTE(review): meaning not visible in this section — confirm.
    pub deterministic: bool,
    /// Name of the rule's probability column, if it has one.
    pub prob_column_name: Option<String>,
}
1165
1166#[expect(clippy::too_many_arguments, reason = "Fixpoint loop needs all context")]
1175async fn run_fixpoint_loop(
1176 rules: Vec<FixpointRulePlan>,
1177 max_iterations: usize,
1178 timeout: Duration,
1179 graph_ctx: Arc<GraphExecutionContext>,
1180 session_ctx: Arc<RwLock<datafusion::prelude::SessionContext>>,
1181 storage: Arc<StorageManager>,
1182 schema_info: Arc<UniSchema>,
1183 params: HashMap<String, Value>,
1184 registry: Arc<DerivedScanRegistry>,
1185 output_schema: SchemaRef,
1186 max_derived_bytes: usize,
1187 derivation_tracker: Option<Arc<ProvenanceStore>>,
1188 iteration_counts: Arc<StdRwLock<HashMap<String, usize>>>,
1189 strict_probability_domain: bool,
1190 probability_epsilon: f64,
1191 exact_probability: bool,
1192 max_bdd_variables: usize,
1193 warnings_slot: Arc<StdRwLock<Vec<RuntimeWarning>>>,
1194 approximate_slot: Arc<StdRwLock<HashMap<String, Vec<String>>>>,
1195 top_k_proofs: usize,
1196 timeout_flag: Arc<std::sync::atomic::AtomicU8>,
1197 semiring_kind: SemiringKind,
1198 classifier_registry: Arc<ClassifierRegistry>,
1199 classifier_cache: Option<Arc<ModelInvocationCache>>,
1200 classifier_provenance_store: Option<Arc<uni_locy::NeuralProvenanceStore>>,
1201) -> DFResult<Vec<RecordBatch>> {
1202 let start = Instant::now();
1203 let task_ctx = session_ctx.read().task_ctx();
1204
1205 if semiring_kind == SemiringKind::MaxMinProb {
1210 let mut warnings = warnings_slot.write().unwrap_or_else(|e| e.into_inner());
1211 let mut already_warned: HashSet<String> = warnings
1212 .iter()
1213 .filter(|w| w.code == RuntimeWarningCode::FuzzyNotProbabilistic)
1214 .map(|w| w.rule_name.clone())
1215 .collect();
1216 for rule in &rules {
1217 if rule.prob_column_name.is_some() && !already_warned.contains(&rule.name) {
1218 warnings.push(RuntimeWarning {
1219 code: RuntimeWarningCode::FuzzyNotProbabilistic,
1220 message: format!(
1221 "rule '{}' carries a PROB column but is being evaluated under \
1222 the MaxMinProb (fuzzy / Viterbi) semiring; outputs are fuzzy \
1223 truth values, not probabilities",
1224 rule.name
1225 ),
1226 rule_name: rule.name.clone(),
1227 variable_count: None,
1228 key_group: None,
1229 });
1230 already_warned.insert(rule.name.clone());
1231 }
1232 }
1233 }
1234
1235 let mut states: Vec<FixpointState> = rules
1237 .iter()
1238 .map(|rule| {
1239 let monotonic_agg = if !rule.fold_bindings.is_empty() {
1240 let bindings: Vec<MonotonicFoldBinding> = rule
1241 .fold_bindings
1242 .iter()
1243 .map(|fb| MonotonicFoldBinding {
1244 fold_name: fb.output_name.clone(),
1245 aggregate: std::sync::Arc::clone(&fb.aggregate),
1246 input_col_index: fb.input_col_index,
1247 input_col_name: fb.input_col_name.clone(),
1248 })
1249 .collect();
1250 Some(MonotonicAggState::new(bindings))
1251 } else {
1252 None
1253 };
1254 FixpointState::new_with_semiring(
1255 rule.name.clone(),
1256 Arc::clone(&rule.yield_schema),
1257 rule.key_column_indices.clone(),
1258 max_derived_bytes,
1259 monotonic_agg,
1260 strict_probability_domain,
1261 semiring_kind,
1262 )
1263 })
1264 .collect();
1265
1266 let mut converged = false;
1268 let mut total_iters = 0usize;
1269 for iteration in 0..max_iterations {
1270 total_iters = iteration + 1;
1271 tracing::debug!("fixpoint iteration {}", iteration);
1272 let mut any_changed = false;
1273
1274 for rule_idx in 0..rules.len() {
1275 let rule = &rules[rule_idx];
1276
1277 update_derived_scan_handles(®istry, &states, rule_idx, &rules);
1279
1280 let mut all_candidates = Vec::new();
1282 let mut clause_candidates: Vec<Vec<RecordBatch>> = Vec::new();
1283 for clause in &rule.clauses {
1284 let mut batches = execute_subplan(
1290 &clause.body_logical,
1291 ¶ms,
1292 &HashMap::new(),
1293 &graph_ctx,
1294 &session_ctx,
1295 &storage,
1296 &schema_info,
1297 None, )
1299 .await?;
1300 for binding in &clause.is_ref_bindings {
1302 if binding.negated
1303 && !binding.anti_join_cols.is_empty()
1304 && let Some(entry) = registry.get(binding.derived_scan_index)
1305 {
1306 let neg_facts = entry.data.read().clone();
1307 if !neg_facts.is_empty() {
1308 if binding.target_has_prob && rule.prob_column_name.is_some() {
1309 let complement_col =
1311 format!("__prob_complement_{}", binding.rule_name);
1312 if let Some(prob_col) = &binding.target_prob_col {
1313 batches = apply_prob_complement_composite(
1314 batches,
1315 &neg_facts,
1316 &binding.anti_join_cols,
1317 prob_col,
1318 &complement_col,
1319 )?;
1320 } else {
1321 batches = apply_anti_join_composite(
1323 batches,
1324 &neg_facts,
1325 &binding.anti_join_cols,
1326 )?;
1327 }
1328 } else {
1329 batches = apply_anti_join_composite(
1331 batches,
1332 &neg_facts,
1333 &binding.anti_join_cols,
1334 )?;
1335 }
1336 }
1337 }
1338 }
1339 let complement_cols: Vec<String> = if !batches.is_empty() {
1341 batches[0]
1342 .schema()
1343 .fields()
1344 .iter()
1345 .filter(|f| f.name().starts_with("__prob_complement_"))
1346 .map(|f| f.name().clone())
1347 .collect()
1348 } else {
1349 vec![]
1350 };
1351 if !complement_cols.is_empty() {
1352 batches = multiply_prob_factors(
1353 batches,
1354 rule.prob_column_name.as_deref(),
1355 &complement_cols,
1356 )?;
1357 }
1358
1359 clause_candidates.push(batches.clone());
1360 all_candidates.extend(batches);
1361 }
1362
1363 let changed = if rule.has_best_by && !rule.best_by_criteria.is_empty() {
1367 states[rule_idx].merge_best_by(all_candidates, &rule.best_by_criteria)?
1368 } else {
1369 states[rule_idx]
1370 .merge_delta(all_candidates, Some(Arc::clone(&task_ctx)))
1371 .await?
1372 };
1373 if changed {
1374 any_changed = true;
1375 if let Some(ref tracker) = derivation_tracker {
1377 record_provenance(
1378 ProvenanceCtx {
1379 tracker,
1380 registry: ®istry,
1381 warnings_slot: &warnings_slot,
1382 },
1383 rule,
1384 &states[rule_idx],
1385 &clause_candidates,
1386 iteration,
1387 top_k_proofs,
1388 ClassifierRefs {
1389 registry: &classifier_registry,
1390 cache: classifier_cache.as_ref(),
1391 provenance_store: classifier_provenance_store.as_ref(),
1392 },
1393 )
1394 .await;
1395 }
1396 }
1397 }
1398
1399 if !any_changed && states.iter().all(|s| s.is_converged()) {
1401 tracing::debug!("fixpoint converged after {} iterations", iteration + 1);
1402 converged = true;
1403 break;
1404 }
1405
1406 if start.elapsed() > timeout {
1408 tracing::warn!(
1409 "fixpoint timeout after {} iterations; returning partial results",
1410 iteration + 1,
1411 );
1412 interruption::set(&timeout_flag, interruption::TIMEOUT);
1413 break;
1414 }
1415 }
1416
1417 if let Ok(mut counts) = iteration_counts.write() {
1419 for rule in &rules {
1420 counts.insert(rule.name.clone(), total_iters);
1421 }
1422 }
1423
1424 if !converged && interruption::reason(&timeout_flag).is_none() {
1429 tracing::warn!(
1430 "fixpoint did not converge after {max_iterations} iterations; returning partial results",
1431 );
1432 interruption::set(&timeout_flag, interruption::ITERATION_LIMIT);
1433 }
1434
1435 let task_ctx = session_ctx.read().task_ctx();
1437 let mut all_output = Vec::new();
1438
1439 for (rule_idx, state) in states.into_iter().enumerate() {
1440 let rule = &rules[rule_idx];
1441 let mut facts = state.into_facts();
1442 if facts.is_empty() {
1443 continue;
1444 }
1445
1446 let shared_info = if semiring_kind == SemiringKind::MaxMinProb {
1464 None
1465 } else if let Some(ref tracker) = derivation_tracker {
1466 detect_shared_lineage(rule, &facts, tracker, &warnings_slot, semiring_kind)
1467 } else {
1468 None
1469 };
1470
1471 if exact_probability
1473 && let Some(ref info) = shared_info
1474 && let Some(ref tracker) = derivation_tracker
1475 {
1476 facts = apply_exact_wmc(
1477 facts,
1478 rule,
1479 info,
1480 tracker,
1481 max_bdd_variables,
1482 &warnings_slot,
1483 &approximate_slot,
1484 )?;
1485 }
1486
1487 let processed = apply_post_fixpoint_chain(
1488 facts,
1489 rule,
1490 &task_ctx,
1491 strict_probability_domain,
1492 probability_epsilon,
1493 semiring_kind,
1494 derivation_tracker.as_ref().map(Arc::clone),
1495 top_k_proofs,
1496 Some(Arc::clone(®istry)),
1497 )
1498 .await?;
1499 all_output.extend(processed);
1500 }
1501
1502 if all_output.is_empty() {
1504 all_output.push(RecordBatch::new_empty(output_schema));
1505 }
1506
1507 Ok(all_output)
1508}
1509
/// Borrowed classifier handles threaded into provenance recording.
///
/// Passed by value to `record_provenance` and
/// `record_and_detect_lineage_nonrecursive`, which forward the fields to
/// `collect_neural_calls_for_row` for each recorded fact.
pub(crate) struct ClassifierRefs<'a> {
    /// Resolves a model name to a classifier (`registry.get(&model_name)`).
    pub registry: &'a Arc<ClassifierRegistry>,
    /// Optional memo of raw classifier outputs, keyed by model name and the
    /// input's stable hash; consulted before invoking the classifier.
    pub cache: Option<&'a Arc<uni_locy::ModelInvocationCache>>,
    /// Optional store of previously recorded neural provenance; a hit here
    /// short-circuits classifier invocation entirely.
    pub provenance_store: Option<&'a Arc<uni_locy::NeuralProvenanceStore>>,
}
1532
/// Borrowed context shared by provenance recording for one rule.
pub(crate) struct ProvenanceCtx<'a> {
    /// Store that receives the per-fact `ProvenanceAnnotation`s
    /// (`record` / `record_top_k`).
    pub tracker: &'a Arc<ProvenanceStore>,
    /// Derived-scan registry used to resolve IS-ref support rows.
    pub registry: &'a Arc<DerivedScanRegistry>,
    /// Shared sink for runtime warnings emitted while recording.
    pub warnings_slot: &'a Arc<StdRwLock<Vec<RuntimeWarning>>>,
}
1543
1544async fn record_provenance(
1545 prov: ProvenanceCtx<'_>,
1546 rule: &FixpointRulePlan,
1547 state: &FixpointState,
1548 clause_candidates: &[Vec<RecordBatch>],
1549 iteration: usize,
1550 top_k_proofs: usize,
1551 classifiers: ClassifierRefs<'_>,
1552) {
1553 let tracker = prov.tracker;
1554 let registry = prov.registry;
1555 let warnings_slot = prov.warnings_slot;
1556 let classifier_registry = classifiers.registry;
1557 let classifier_cache = classifiers.cache;
1558 let all_indices: Vec<usize> = (0..rule.yield_schema.fields().len()).collect();
1559
1560 let base_probs = if top_k_proofs > 0 {
1562 tracker.base_fact_probs()
1563 } else {
1564 HashMap::new()
1565 };
1566
1567 let mut topk_acc = TopKProofAccumulator::new();
1568
1569 for delta_batch in state.all_delta() {
1570 for row_idx in 0..delta_batch.num_rows() {
1571 let row_hash = format!(
1572 "{:?}",
1573 extract_scalar_key(delta_batch, &all_indices, row_idx)
1574 )
1575 .into_bytes();
1576 let fact_row = batch_row_to_value_map(delta_batch, row_idx);
1577 let clause_index =
1578 find_clause_for_row(delta_batch, row_idx, &all_indices, clause_candidates);
1579
1580 let support = collect_is_ref_inputs(rule, clause_index, delta_batch, row_idx, registry);
1581
1582 let proof_probability = if top_k_proofs > 0 {
1583 compute_proof_probability(&support, &base_probs)
1584 } else {
1585 None
1586 };
1587
1588 let entry = ProvenanceAnnotation {
1589 rule_name: rule.name.clone(),
1590 clause_index,
1591 support,
1592 along_values: {
1593 let along_names: Vec<String> = rule
1594 .clauses
1595 .get(clause_index)
1596 .map(|c| c.along_bindings.clone())
1597 .unwrap_or_default();
1598 along_names
1599 .iter()
1600 .filter_map(|name| fact_row.get(name).map(|v| (name.clone(), v.clone())))
1601 .collect()
1602 },
1603 iteration,
1604 fact_row: fact_row.clone(),
1605 proof_probability,
1606 neural_calls: collect_neural_calls_for_row(
1607 rule,
1608 clause_index,
1609 &fact_row,
1610 classifier_registry,
1611 classifier_cache,
1612 classifiers.provenance_store,
1613 )
1614 .await,
1615 };
1616 if top_k_proofs > 0 {
1617 topk_acc.accumulate(&entry, &row_hash);
1618 tracker.record_top_k(row_hash, entry, top_k_proofs);
1619 } else {
1620 tracker.record(row_hash, entry);
1621 }
1622 }
1623 }
1624
1625 topk_acc.emit_warning_if_any(rule, top_k_proofs, warnings_slot);
1626}
1627
1628struct TopKProofAccumulator {
1635 per_fact: HashMap<Vec<u8>, Vec<uni_locy::Proof>>,
1636 base_rv_interner: HashMap<Vec<u8>, uni_locy::BaseRv>,
1637 next_rv: u32,
1638}
1639
1640impl TopKProofAccumulator {
1641 fn new() -> Self {
1642 Self {
1643 per_fact: HashMap::new(),
1644 base_rv_interner: HashMap::new(),
1645 next_rv: 0,
1646 }
1647 }
1648
1649 fn accumulate(&mut self, entry: &ProvenanceAnnotation, row_hash: &[u8]) {
1650 let mut base_rvs = uni_locy::BaseRvSet::empty();
1651 for term in &entry.support {
1652 let rv = *self
1653 .base_rv_interner
1654 .entry(term.base_fact_id.clone())
1655 .or_insert_with(|| {
1656 let r = uni_locy::BaseRv(self.next_rv);
1657 self.next_rv += 1;
1658 r
1659 });
1660 base_rvs.insert(rv);
1661 }
1662 self.per_fact
1663 .entry(row_hash.to_vec())
1664 .or_default()
1665 .push(uni_locy::Proof {
1666 weight: entry.proof_probability.unwrap_or(0.0),
1667 base_rvs,
1668 neural_calls: Vec::new(),
1669 });
1670 }
1671
1672 fn emit_warning_if_any(
1673 &self,
1674 rule: &FixpointRulePlan,
1675 top_k_proofs: usize,
1676 warnings_slot: &Arc<StdRwLock<Vec<RuntimeWarning>>>,
1677 ) {
1678 if top_k_proofs == 0 || self.per_fact.is_empty() {
1679 return;
1680 }
1681 let crossed_facts = self
1682 .per_fact
1683 .values()
1684 .filter(|proofs| {
1685 let (_kept, notice) =
1686 uni_locy::merge_top_k_runtime(Vec::new(), (*proofs).clone(), top_k_proofs);
1687 notice == uni_locy::PruneNotice::CrossedDependency
1688 })
1689 .count();
1690 if crossed_facts == 0 {
1691 return;
1692 }
1693 let Ok(mut w) = warnings_slot.write() else {
1694 return;
1695 };
1696 let already = w.iter().any(|rw| {
1697 matches!(
1698 rw.code,
1699 uni_locy::types::RuntimeWarningCode::TopKPruningCrossedDependency
1700 ) && rw.rule_name == rule.name
1701 });
1702 if already {
1703 return;
1704 }
1705 w.push(RuntimeWarning {
1706 code: uni_locy::types::RuntimeWarningCode::TopKPruningCrossedDependency,
1707 rule_name: rule.name.clone(),
1708 message: format!(
1709 "rule '{}': top-K proof pruning (k={}) discarded {} fact(s) \
1710 whose dependencies overlap retained proofs. The retained \
1711 top-{} under-counts the true joint probability for those \
1712 facts (Scallop, Huang et al. 2021). Increase k to recover.",
1713 rule.name, top_k_proofs, crossed_facts, top_k_proofs
1714 ),
1715 variable_count: None,
1716 key_group: None,
1717 });
1718 }
1719}
1720
1721async fn collect_neural_calls_for_row(
1743 rule: &FixpointRulePlan,
1744 clause_index: usize,
1745 fact_row: &uni_locy::FactRow,
1746 classifier_registry: &Arc<ClassifierRegistry>,
1747 classifier_cache: Option<&Arc<uni_locy::ModelInvocationCache>>,
1748 provenance_store: Option<&Arc<uni_locy::NeuralProvenanceStore>>,
1749) -> Vec<uni_locy::NeuralProvenance> {
1750 let Some(clause) = rule.clauses.get(clause_index) else {
1751 return Vec::new();
1752 };
1753 if clause.model_invocations.is_empty() {
1754 return Vec::new();
1755 }
1756 let mut out = Vec::with_capacity(clause.model_invocations.len());
1757 for invocation in &clause.model_invocations {
1758 let mut features = std::collections::HashMap::new();
1769 for (binding_name, feat_expr) in invocation
1770 .feature_names
1771 .iter()
1772 .zip(invocation.feature_exprs.iter())
1773 {
1774 features.insert(
1775 binding_name.clone(),
1776 eval_feature_expr_against_fact_row(feat_expr, fact_row),
1777 );
1778 }
1779 let input = uni_locy::ClassifyInput { features };
1780 let input_hash = input.stable_hash();
1781
1782 if let Some(store) = provenance_store
1789 && let Some(record) = store.get(&invocation.model_name, input_hash)
1790 {
1791 out.push(uni_locy::NeuralProvenance {
1792 model_name: invocation.model_name.clone(),
1793 raw_probability: record.raw_probability,
1794 calibrated_probability: record.calibrated_probability,
1795 confidence_band: record.confidence_band,
1796 });
1797 continue;
1798 }
1799
1800 let Some(classifier) = classifier_registry.get(&invocation.model_name) else {
1805 continue;
1806 };
1807 let raw = if let Some(v) =
1808 classifier_cache.and_then(|c| c.get(&invocation.model_name, input_hash))
1809 {
1810 v
1811 } else {
1812 match classifier.classify(std::slice::from_ref(&input)).await {
1813 Ok(probs) => {
1814 let v = probs.first().copied().unwrap_or(0.0);
1815 if let Some(c) = classifier_cache {
1816 c.insert(&invocation.model_name, input_hash, v);
1817 }
1818 v
1819 }
1820 Err(_) => continue,
1821 }
1822 };
1823 let calibrator = classifier.get_calibrator();
1824 let calibrated_probability = calibrator.as_ref().map(|_| raw);
1825 let confidence_band = calibrator.as_ref().and_then(|c| c.confidence_band(raw));
1826 out.push(uni_locy::NeuralProvenance {
1827 model_name: invocation.model_name.clone(),
1828 raw_probability: raw,
1829 calibrated_probability,
1830 confidence_band,
1831 });
1832 }
1833 out
1834}
1835
1836fn eval_feature_expr_against_fact_row(
1844 expr: &uni_cypher::ast::Expr,
1845 fact_row: &uni_locy::FactRow,
1846) -> uni_locy::FeatureValue {
1847 use uni_cypher::ast::Expr;
1848 use uni_locy::FeatureValue;
1849 let value_to_feature = |v: Option<&uni_common::Value>| -> FeatureValue {
1850 match v {
1851 Some(uni_common::Value::Float(f)) => FeatureValue::Float(*f),
1852 Some(uni_common::Value::Int(i)) => FeatureValue::Int(*i),
1853 Some(uni_common::Value::Bool(b)) => FeatureValue::Bool(*b),
1854 Some(uni_common::Value::String(s)) => FeatureValue::String(s.clone()),
1855 Some(uni_common::Value::Node(n)) => {
1856 FeatureValue::Int(n.vid.as_u64() as i64)
1858 }
1859 _ => FeatureValue::Null,
1860 }
1861 };
1862 let resolve_value = |sub: &Expr| -> uni_common::Value {
1866 match sub {
1867 Expr::Variable(name) => fact_row
1868 .get(name)
1869 .cloned()
1870 .unwrap_or(uni_common::Value::Null),
1871 Expr::Property(boxed, prop) if matches!(boxed.as_ref(), Expr::Variable(_)) => {
1872 let Expr::Variable(v) = boxed.as_ref() else {
1873 unreachable!()
1874 };
1875 let key = format!("{}.{}", v, prop);
1876 if let Some(val) = fact_row.get(&key) {
1877 return val.clone();
1878 }
1879 if let Some(uni_common::Value::Node(n)) = fact_row.get(v) {
1880 return n
1881 .properties
1882 .get(prop)
1883 .cloned()
1884 .unwrap_or(uni_common::Value::Null);
1885 }
1886 uni_common::Value::Null
1887 }
1888 Expr::Literal(lit) => lit.to_value(),
1889 Expr::List(items) => {
1890 let mut out = Vec::with_capacity(items.len());
1891 for it in items {
1892 out.push(match it {
1893 Expr::Literal(lit) => lit.to_value(),
1894 _ => uni_common::Value::Null,
1895 });
1896 }
1897 uni_common::Value::List(out)
1898 }
1899 _ => uni_common::Value::Null,
1900 }
1901 };
1902
1903 match expr {
1904 Expr::Variable(name) => value_to_feature(fact_row.get(name)),
1905 Expr::Property(boxed, prop) => {
1906 if let Expr::Variable(v) = boxed.as_ref() {
1907 let key = format!("{}.{}", v, prop);
1909 if let Some(val) = fact_row.get(&key) {
1910 return value_to_feature(Some(val));
1911 }
1912 let hidden_key = format!("__feat_{}_{}", v, prop);
1922 if let Some(val) = fact_row.get(&hidden_key) {
1923 return value_to_feature(Some(val));
1924 }
1925 if let Some(uni_common::Value::Node(n)) = fact_row.get(v) {
1929 return value_to_feature(n.properties.get(prop));
1930 }
1931 }
1932 FeatureValue::Null
1933 }
1934 Expr::FunctionCall { name, args, .. } if name == "similar_to" && args.len() == 2 => {
1935 let lv = resolve_value(&args[0]);
1936 let rv = resolve_value(&args[1]);
1937 match crate::query::similar_to::eval_similar_to_pure(&lv, &rv) {
1938 Ok(uni_common::Value::Float(f)) => FeatureValue::Float(f),
1939 _ => FeatureValue::Null,
1940 }
1941 }
1942 Expr::FunctionCall { name, .. }
1957 if matches!(
1958 name.as_str(),
1959 "degree_centrality"
1960 | "pagerank_score"
1961 | "closeness_centrality"
1962 | "betweenness_centrality"
1963 | "eigenvector_centrality"
1964 | "harmonic_centrality"
1965 | "katz_centrality"
1966 | "avg_neighbor"
1967 | "max_neighbor"
1968 | "sum_neighbor"
1969 ) =>
1970 {
1971 FeatureValue::Null
1972 }
1973 _ => FeatureValue::Null,
1974 }
1975}
1976
1977fn collect_is_ref_inputs(
1978 rule: &FixpointRulePlan,
1979 clause_index: usize,
1980 delta_batch: &RecordBatch,
1981 row_idx: usize,
1982 registry: &Arc<DerivedScanRegistry>,
1983) -> Vec<ProofTerm> {
1984 let clause = match rule.clauses.get(clause_index) {
1985 Some(c) => c,
1986 None => return vec![],
1987 };
1988
1989 let mut inputs = Vec::new();
1990 let delta_schema = delta_batch.schema();
1991
1992 for binding in &clause.is_ref_bindings {
1993 if binding.negated {
1994 continue;
1995 }
1996 if binding.provenance_join_cols.is_empty() {
1997 continue;
1998 }
1999
2000 let body_values: Vec<(String, ScalarKey)> = binding
2002 .provenance_join_cols
2003 .iter()
2004 .filter_map(|(body_col, _derived_col)| {
2005 let col_idx = delta_schema
2006 .fields()
2007 .iter()
2008 .position(|f| f.name() == body_col)?;
2009 let key = extract_scalar_key(delta_batch, &[col_idx], row_idx);
2010 Some((body_col.clone(), key.into_iter().next()?))
2011 })
2012 .collect();
2013
2014 if body_values.len() != binding.provenance_join_cols.len() {
2015 continue;
2016 }
2017
2018 let entry = match registry.get(binding.derived_scan_index) {
2020 Some(e) => e,
2021 None => continue,
2022 };
2023 let source_batches = entry.data.read();
2024 let source_schema = &entry.schema;
2025
2026 for src_batch in source_batches.iter() {
2028 let all_src_indices: Vec<usize> = (0..src_batch.num_columns()).collect();
2029 for src_row in 0..src_batch.num_rows() {
2030 let matches = binding.provenance_join_cols.iter().enumerate().all(
2031 |(i, (_body_col, derived_col))| {
2032 let src_col_idx = source_schema
2033 .fields()
2034 .iter()
2035 .position(|f| f.name() == derived_col);
2036 match src_col_idx {
2037 Some(idx) => {
2038 let src_key = extract_scalar_key(src_batch, &[idx], src_row);
2039 src_key.first() == Some(&body_values[i].1)
2040 }
2041 None => false,
2042 }
2043 },
2044 );
2045 if matches {
2046 let fact_hash = format!(
2047 "{:?}",
2048 extract_scalar_key(src_batch, &all_src_indices, src_row)
2049 )
2050 .into_bytes();
2051 inputs.push(ProofTerm {
2052 source_rule: binding.rule_name.clone(),
2053 base_fact_id: fact_hash,
2054 });
2055 }
2056 }
2057 }
2058 }
2059
2060 inputs
2061}
2062
2063fn collect_is_ref_inputs_for_body_row(
2085 rule: &FixpointRulePlan,
2086 delta_batch: &RecordBatch,
2087 row_idx: usize,
2088 registry: &Arc<DerivedScanRegistry>,
2089) -> Vec<ProofTerm> {
2090 let mut combined: Vec<ProofTerm> = Vec::new();
2091 for clause_index in 0..rule.clauses.len() {
2092 let part = collect_is_ref_inputs(rule, clause_index, delta_batch, row_idx, registry);
2093 combined.extend(part);
2094 }
2095 combined
2096}
2097
/// One row of a KEY group whose rows were found to share lineage.
#[expect(
    dead_code,
    reason = "Fields accessed via SharedLineageInfo in detect_shared_lineage"
)]
pub(crate) struct SharedGroupRow {
    /// Hash of the full row, as produced by `fact_hash_key`.
    pub fact_hash: Vec<u8>,
    /// Leaf base-fact hashes reachable from this row (from `compute_lineage`).
    pub lineage: HashSet<Vec<u8>>,
}
2124
/// Result of `detect_shared_lineage`: the KEY groups whose rows share lineage.
pub(crate) struct SharedLineageInfo {
    /// Shared groups keyed by their KEY-column values; `apply_exact_wmc`
    /// recomputes the probability of exactly these groups.
    pub shared_groups: HashMap<Vec<ScalarKey>, Vec<SharedGroupRow>>,
}
2130
2131pub(crate) fn fact_hash_key(batch: &RecordBatch, all_indices: &[usize], row_idx: usize) -> Vec<u8> {
2133 format!("{:?}", extract_scalar_key(batch, all_indices, row_idx)).into_bytes()
2134}
2135
2136fn detect_shared_lineage(
2139 rule: &FixpointRulePlan,
2140 pre_fold_facts: &[RecordBatch],
2141 tracker: &Arc<ProvenanceStore>,
2142 warnings_slot: &Arc<StdRwLock<Vec<RuntimeWarning>>>,
2143 semiring_kind: SemiringKind,
2144) -> Option<SharedLineageInfo> {
2145 use uni_locy::{RuntimeWarning, RuntimeWarningCode};
2146
2147 let has_prob_fold = rule
2152 .fold_bindings
2153 .iter()
2154 .any(|fb| fb.aggregate.is_probability_aggregate());
2155 if !has_prob_fold {
2156 return None;
2157 }
2158
2159 let key_indices = &rule.key_column_indices;
2161 let all_indices: Vec<usize> = (0..rule.yield_schema.fields().len()).collect();
2162
2163 let mut groups: HashMap<Vec<ScalarKey>, Vec<Vec<u8>>> = HashMap::new();
2164 for batch in pre_fold_facts {
2165 for row_idx in 0..batch.num_rows() {
2166 let key = extract_scalar_key(batch, key_indices, row_idx);
2167 let fact_hash = fact_hash_key(batch, &all_indices, row_idx);
2168 groups.entry(key).or_default().push(fact_hash);
2169 }
2170 }
2171
2172 let mut shared_groups: HashMap<Vec<ScalarKey>, Vec<SharedGroupRow>> = HashMap::new();
2173 let mut any_shared = false;
2174
2175 for (key, fact_hashes) in &groups {
2177 if fact_hashes.len() < 2 {
2178 continue;
2179 }
2180
2181 let mut has_inputs = false;
2183 let mut per_row_bases: Vec<HashSet<Vec<u8>>> = Vec::new();
2184 for fh in fact_hashes {
2185 let bases = compute_lineage(fh, tracker, &mut HashSet::new());
2186 if let Some(entry) = tracker.lookup(fh)
2187 && !entry.support.is_empty()
2188 {
2189 has_inputs = true;
2190 }
2191 per_row_bases.push(bases);
2192 }
2193
2194 let shared_found = if has_inputs {
2195 let mut found = false;
2197 'outer: for i in 0..per_row_bases.len() {
2198 for j in (i + 1)..per_row_bases.len() {
2199 if !per_row_bases[i].is_disjoint(&per_row_bases[j]) {
2200 found = true;
2201 break 'outer;
2202 }
2203 }
2204 }
2205 found
2206 } else {
2207 fact_hashes.iter().any(|fh| {
2210 tracker.lookup(fh).is_some_and(|entry| {
2211 rule.clauses
2212 .get(entry.clause_index)
2213 .is_some_and(|clause| clause.is_ref_bindings.iter().any(|b| !b.negated))
2214 })
2215 })
2216 };
2217
2218 if shared_found {
2219 any_shared = true;
2220 let rows: Vec<SharedGroupRow> = fact_hashes
2222 .iter()
2223 .zip(per_row_bases)
2224 .map(|(fh, bases)| SharedGroupRow {
2225 fact_hash: fh.clone(),
2226 lineage: bases,
2227 })
2228 .collect();
2229 shared_groups.insert(key.clone(), rows);
2230 }
2231 }
2232
2233 {
2239 let mut input_to_groups: HashMap<Vec<u8>, HashSet<Vec<ScalarKey>>> = HashMap::new();
2240 for (key, fact_hashes) in &groups {
2241 for fh in fact_hashes {
2242 if let Some(entry) = tracker.lookup(fh) {
2243 for input in &entry.support {
2244 input_to_groups
2245 .entry(input.base_fact_id.clone())
2246 .or_default()
2247 .insert(key.clone());
2248 }
2249 }
2250 }
2251 }
2252 let has_cross_group = input_to_groups.values().any(|g| g.len() > 1);
2253 if has_cross_group && let Ok(mut warnings) = warnings_slot.write() {
2254 let already_warned = warnings.iter().any(|w| {
2255 w.code == RuntimeWarningCode::CrossGroupCorrelationNotExact
2256 && w.rule_name == rule.name
2257 });
2258 if !already_warned {
2259 let example =
2263 input_to_groups
2264 .iter()
2265 .find(|(_, g)| g.len() > 1)
2266 .map(|(input, groups)| {
2267 let short = input
2268 .iter()
2269 .take(8)
2270 .map(|b| format!("{:02x}", b))
2271 .collect::<String>();
2272 let mut group_strs: Vec<String> =
2273 groups.iter().map(|k| format!("{:?}", k)).collect();
2274 group_strs.sort();
2275 format!(
2276 "input {} shared by groups [{}]",
2277 short,
2278 group_strs.join(", ")
2279 )
2280 });
2281 let shared_variable_count =
2287 input_to_groups.values().filter(|g| g.len() > 1).count();
2288 warnings.push(RuntimeWarning {
2289 code: RuntimeWarningCode::CrossGroupCorrelationNotExact,
2290 message: format!(
2291 "Rule '{}': {} IS-ref base fact(s) are shared across different \
2292 KEY groups. BDD corrects per-group probabilities but cannot \
2293 account for cross-group correlations.",
2294 rule.name, shared_variable_count
2295 ),
2296 rule_name: rule.name.clone(),
2297 variable_count: Some(shared_variable_count),
2298 key_group: example,
2299 });
2300 }
2301 }
2302 }
2303
2304 if any_shared {
2305 let suppress_under_topk = matches!(semiring_kind, SemiringKind::TopKProofs { .. });
2315 if !suppress_under_topk && let Ok(mut warnings) = warnings_slot.write() {
2316 let already_warned = warnings.iter().any(|w| {
2317 w.code == RuntimeWarningCode::SharedProbabilisticDependency
2318 && w.rule_name == rule.name
2319 });
2320 if !already_warned {
2321 warnings.push(RuntimeWarning {
2322 code: RuntimeWarningCode::SharedProbabilisticDependency,
2323 message: format!(
2324 "Rule '{}' aggregates with MNOR/MPROD but some proof paths \
2325 share intermediate facts, violating the independence assumption. \
2326 Results may overestimate probability.",
2327 rule.name
2328 ),
2329 rule_name: rule.name.clone(),
2330 variable_count: None,
2331 key_group: None,
2332 });
2333 }
2334 }
2335 Some(SharedLineageInfo { shared_groups })
2336 } else {
2337 None
2338 }
2339}
2340
2341#[allow(
2349 clippy::too_many_arguments,
2350 reason = "context bundle would be over-engineering for one call site"
2351)]
2352pub(crate) async fn record_and_detect_lineage_nonrecursive(
2353 rule: &FixpointRulePlan,
2354 tagged_clause_facts: &[(usize, Vec<RecordBatch>)],
2355 tracker: &Arc<ProvenanceStore>,
2356 warnings_slot: &Arc<StdRwLock<Vec<RuntimeWarning>>>,
2357 registry: &Arc<DerivedScanRegistry>,
2358 top_k_proofs: usize,
2359 classifiers: ClassifierRefs<'_>,
2360 semiring_kind: SemiringKind,
2361) -> Option<SharedLineageInfo> {
2362 let classifier_registry = classifiers.registry;
2363 let classifier_cache = classifiers.cache;
2364 let all_indices: Vec<usize> = (0..rule.yield_schema.fields().len()).collect();
2365
2366 let base_probs = if top_k_proofs > 0 {
2368 tracker.base_fact_probs()
2369 } else {
2370 HashMap::new()
2371 };
2372
2373 let mut topk_acc = TopKProofAccumulator::new();
2374
2375 for (clause_index, batches) in tagged_clause_facts {
2377 for batch in batches {
2378 for row_idx in 0..batch.num_rows() {
2379 let row_hash = fact_hash_key(batch, &all_indices, row_idx);
2380 let fact_row = batch_row_to_value_map(batch, row_idx);
2381
2382 let support = collect_is_ref_inputs(rule, *clause_index, batch, row_idx, registry);
2383
2384 let proof_probability = if top_k_proofs > 0 {
2385 compute_proof_probability(&support, &base_probs)
2386 } else {
2387 None
2388 };
2389
2390 let entry = ProvenanceAnnotation {
2391 rule_name: rule.name.clone(),
2392 clause_index: *clause_index,
2393 support,
2394 along_values: {
2395 let along_names: Vec<String> = rule
2396 .clauses
2397 .get(*clause_index)
2398 .map(|c| c.along_bindings.clone())
2399 .unwrap_or_default();
2400 along_names
2401 .iter()
2402 .filter_map(|name| {
2403 fact_row.get(name).map(|v| (name.clone(), v.clone()))
2404 })
2405 .collect()
2406 },
2407 iteration: 0,
2408 fact_row: fact_row.clone(),
2409 proof_probability,
2410 neural_calls: collect_neural_calls_for_row(
2411 rule,
2412 *clause_index,
2413 &fact_row,
2414 classifier_registry,
2415 classifier_cache,
2416 classifiers.provenance_store,
2417 )
2418 .await,
2419 };
2420 if top_k_proofs > 0 {
2421 topk_acc.accumulate(&entry, &row_hash);
2422 tracker.record_top_k(row_hash, entry, top_k_proofs);
2423 } else {
2424 tracker.record(row_hash, entry);
2425 }
2426 }
2427 }
2428 }
2429
2430 topk_acc.emit_warning_if_any(rule, top_k_proofs, warnings_slot);
2431
2432 let all_facts: Vec<RecordBatch> = tagged_clause_facts
2434 .iter()
2435 .flat_map(|(_, batches)| batches.iter().cloned())
2436 .collect();
2437 detect_shared_lineage(rule, &all_facts, tracker, warnings_slot, semiring_kind)
2438}
2439
2440pub(crate) fn apply_exact_wmc(
2448 pre_fold_facts: Vec<RecordBatch>,
2449 rule: &FixpointRulePlan,
2450 shared_info: &SharedLineageInfo,
2451 tracker: &Arc<ProvenanceStore>,
2452 max_bdd_variables: usize,
2453 warnings_slot: &Arc<StdRwLock<Vec<RuntimeWarning>>>,
2454 approximate_slot: &Arc<StdRwLock<HashMap<String, Vec<String>>>>,
2455) -> DFResult<Vec<RecordBatch>> {
2456 use crate::query::df_graph::locy_bdd::{SemiringOp, weighted_model_count};
2457 use uni_locy::{RuntimeWarning, RuntimeWarningCode};
2458
2459 let prob_fold = rule
2463 .fold_bindings
2464 .iter()
2465 .find(|fb| fb.aggregate.is_probability_aggregate());
2466 let prob_fold = match prob_fold {
2467 Some(f) => f,
2468 None => return Ok(pre_fold_facts),
2469 };
2470 let semiring_op = if prob_fold.aggregate.is_noisy_or() {
2471 SemiringOp::Disjunction
2472 } else {
2473 SemiringOp::Conjunction
2474 };
2475 let prob_col_idx = prob_fold.input_col_index;
2476 let prob_col_name = rule.yield_schema.field(prob_col_idx).name().clone();
2477
2478 let key_indices = &rule.key_column_indices;
2479 let all_indices: Vec<usize> = (0..rule.yield_schema.fields().len()).collect();
2480
2481 let shared_keys: HashSet<Vec<ScalarKey>> = shared_info.shared_groups.keys().cloned().collect();
2483
2484 struct GroupAccum {
2487 base_facts: Vec<HashSet<Vec<u8>>>,
2488 base_probs: HashMap<Vec<u8>, f64>,
2489 representative: (usize, usize),
2491 row_locations: Vec<(usize, usize)>,
2492 }
2493
2494 let mut group_accums: HashMap<Vec<ScalarKey>, GroupAccum> = HashMap::new();
2495 let mut non_shared_rows: Vec<(usize, usize)> = Vec::new(); for (batch_idx, batch) in pre_fold_facts.iter().enumerate() {
2498 for row_idx in 0..batch.num_rows() {
2499 let key = extract_scalar_key(batch, key_indices, row_idx);
2500 if shared_keys.contains(&key) {
2501 let fact_hash = fact_hash_key(batch, &all_indices, row_idx);
2502 let bases = compute_lineage(&fact_hash, tracker, &mut HashSet::new());
2503
2504 let accum = group_accums.entry(key).or_insert_with(|| GroupAccum {
2505 base_facts: Vec::new(),
2506 base_probs: HashMap::new(),
2507 representative: (batch_idx, row_idx),
2508 row_locations: Vec::new(),
2509 });
2510
2511 for bf in &bases {
2513 if !accum.base_probs.contains_key(bf)
2514 && let Some(entry) = tracker.lookup(bf)
2515 && let Some(val) = entry.fact_row.get(&prob_col_name)
2516 && let Some(p) = value_to_f64(val)
2517 {
2518 accum.base_probs.insert(bf.clone(), p);
2519 }
2520 }
2521
2522 accum.base_facts.push(bases);
2523 accum.row_locations.push((batch_idx, row_idx));
2524 } else {
2525 non_shared_rows.push((batch_idx, row_idx));
2526 }
2527 }
2528 }
2529
2530 let mut keep_rows: HashSet<(usize, usize)> = HashSet::new();
2533 let mut overrides: HashMap<(usize, usize), f64> = HashMap::new();
2535
2536 for &loc in &non_shared_rows {
2538 keep_rows.insert(loc);
2539 }
2540
2541 for (key, accum) in &group_accums {
2542 let bdd_result = weighted_model_count(
2543 &accum.base_facts,
2544 &accum.base_probs,
2545 semiring_op,
2546 max_bdd_variables,
2547 );
2548
2549 if bdd_result.approximated {
2550 if let Ok(mut warnings) = warnings_slot.write() {
2552 let key_desc = format!("{key:?}");
2553 let already_warned = warnings.iter().any(|w| {
2554 w.code == RuntimeWarningCode::BddLimitExceeded
2555 && w.rule_name == rule.name
2556 && w.key_group.as_deref() == Some(&key_desc)
2557 });
2558 if !already_warned {
2559 warnings.push(RuntimeWarning {
2560 code: RuntimeWarningCode::BddLimitExceeded,
2561 message: format!(
2562 "Rule '{}': BDD variable limit exceeded ({} > {}). \
2563 Falling back to independence-mode result.",
2564 rule.name, bdd_result.variable_count, max_bdd_variables
2565 ),
2566 rule_name: rule.name.clone(),
2567 variable_count: Some(bdd_result.variable_count),
2568 key_group: Some(key_desc),
2569 });
2570 }
2571 }
2572 if let Ok(mut approx) = approximate_slot.write() {
2573 let key_desc = format!("{key:?}");
2574 approx.entry(rule.name.clone()).or_default().push(key_desc);
2575 }
2576 for &loc in &accum.row_locations {
2578 keep_rows.insert(loc);
2579 }
2580 } else {
2581 keep_rows.insert(accum.representative);
2583 overrides.insert(accum.representative, bdd_result.probability);
2584 }
2585 }
2586
2587 let mut result_batches = Vec::new();
2589 for (batch_idx, batch) in pre_fold_facts.iter().enumerate() {
2590 let kept_indices: Vec<usize> = (0..batch.num_rows())
2591 .filter(|&row_idx| keep_rows.contains(&(batch_idx, row_idx)))
2592 .collect();
2593
2594 if kept_indices.is_empty() {
2595 continue;
2596 }
2597
2598 let indices = arrow::array::UInt32Array::from(
2599 kept_indices.iter().map(|&i| i as u32).collect::<Vec<_>>(),
2600 );
2601 let mut columns: Vec<arrow::array::ArrayRef> = batch
2602 .columns()
2603 .iter()
2604 .map(|col| arrow::compute::take(col, &indices, None))
2605 .collect::<Result<Vec<_>, _>>()
2606 .map_err(arrow_err)?;
2607
2608 let override_map: Vec<Option<f64>> = kept_indices
2610 .iter()
2611 .map(|&row_idx| overrides.get(&(batch_idx, row_idx)).copied())
2612 .collect();
2613
2614 if override_map.iter().any(|o| o.is_some()) && prob_col_idx < columns.len() {
2615 let existing_prob = columns[prob_col_idx]
2617 .as_any()
2618 .downcast_ref::<arrow::array::Float64Array>();
2619 let new_values: Vec<f64> = override_map
2620 .iter()
2621 .enumerate()
2622 .map(|(i, ov)| match ov {
2623 Some(p) => *p,
2624 None => existing_prob.map(|arr| arr.value(i)).unwrap_or(0.0),
2625 })
2626 .collect();
2627 columns[prob_col_idx] = Arc::new(arrow::array::Float64Array::from(new_values));
2628 }
2629
2630 let result_batch = RecordBatch::try_new(batch.schema(), columns).map_err(arrow_err)?;
2631 result_batches.push(result_batch);
2632 }
2633
2634 Ok(result_batches)
2635}
2636
2637fn value_to_f64(val: &uni_common::Value) -> Option<f64> {
2639 match val {
2640 uni_common::Value::Float(f) => Some(*f),
2641 uni_common::Value::Int(i) => Some(*i as f64),
2642 _ => None,
2643 }
2644}
2645
2646fn compute_lineage(
2653 fact_hash: &[u8],
2654 tracker: &Arc<ProvenanceStore>,
2655 visited: &mut HashSet<Vec<u8>>,
2656) -> HashSet<Vec<u8>> {
2657 if !visited.insert(fact_hash.to_vec()) {
2658 return HashSet::new(); }
2660
2661 match tracker.lookup(fact_hash) {
2662 Some(entry) if !entry.support.is_empty() => {
2663 let mut bases = HashSet::new();
2664 for input in &entry.support {
2665 let child_bases = compute_lineage(&input.base_fact_id, tracker, visited);
2666 bases.extend(child_bases);
2667 }
2668 bases
2669 }
2670 _ => {
2671 let mut set = HashSet::new();
2673 set.insert(fact_hash.to_vec());
2674 set
2675 }
2676 }
2677}
2678
2679fn find_clause_for_row(
2684 delta_batch: &RecordBatch,
2685 row_idx: usize,
2686 all_indices: &[usize],
2687 clause_candidates: &[Vec<RecordBatch>],
2688) -> usize {
2689 let target_key = extract_scalar_key(delta_batch, all_indices, row_idx);
2690 for (clause_idx, batches) in clause_candidates.iter().enumerate() {
2691 for batch in batches {
2692 if batch.num_columns() != all_indices.len() {
2693 continue;
2694 }
2695 for r in 0..batch.num_rows() {
2696 if extract_scalar_key(batch, all_indices, r) == target_key {
2697 return clause_idx;
2698 }
2699 }
2700 }
2701 }
2702 0
2703}
2704
2705fn batch_row_to_value_map(
2707 batch: &RecordBatch,
2708 row_idx: usize,
2709) -> std::collections::HashMap<String, Value> {
2710 use uni_store::storage::arrow_convert::arrow_to_value;
2711
2712 let schema = batch.schema();
2713 schema
2714 .fields()
2715 .iter()
2716 .enumerate()
2717 .map(|(col_idx, field)| {
2718 let col = batch.column(col_idx);
2719 let val = arrow_to_value(col.as_ref(), row_idx, None);
2720 (field.name().clone(), val)
2721 })
2722 .collect()
2723}
2724
2725pub fn apply_anti_join(
2730 batches: Vec<RecordBatch>,
2731 neg_facts: &[RecordBatch],
2732 left_col: &str,
2733 right_col: &str,
2734) -> datafusion::error::Result<Vec<RecordBatch>> {
2735 use arrow::compute::filter_record_batch;
2736 use arrow_array::{Array as _, BooleanArray, UInt64Array};
2737
2738 let mut banned: std::collections::HashSet<u64> = std::collections::HashSet::new();
2740 for batch in neg_facts {
2741 let Ok(idx) = batch.schema().index_of(right_col) else {
2742 continue;
2743 };
2744 let arr = batch.column(idx);
2745 let Some(vids) = arr.as_any().downcast_ref::<UInt64Array>() else {
2746 continue;
2747 };
2748 for i in 0..vids.len() {
2749 if !vids.is_null(i) {
2750 banned.insert(vids.value(i));
2751 }
2752 }
2753 }
2754
2755 if banned.is_empty() {
2756 return Ok(batches);
2757 }
2758
2759 let mut result = Vec::new();
2761 for batch in batches {
2762 let Ok(idx) = batch.schema().index_of(left_col) else {
2763 result.push(batch);
2764 continue;
2765 };
2766 let arr = batch.column(idx);
2767 let Some(vids) = arr.as_any().downcast_ref::<UInt64Array>() else {
2768 result.push(batch);
2769 continue;
2770 };
2771 let keep: Vec<bool> = (0..vids.len())
2772 .map(|i| vids.is_null(i) || !banned.contains(&vids.value(i)))
2773 .collect();
2774 let keep_arr = BooleanArray::from(keep);
2775 let filtered = filter_record_batch(&batch, &keep_arr).map_err(arrow_err)?;
2776 if filtered.num_rows() > 0 {
2777 result.push(filtered);
2778 }
2779 }
2780 Ok(result)
2781}
2782
2783#[allow(clippy::too_many_arguments)]
2804pub(crate) async fn apply_model_invocations(
2805 batches: Vec<RecordBatch>,
2806 invocations: &[uni_locy::ModelInvocation],
2807 registry: &Arc<ClassifierRegistry>,
2808 cache: Option<&Arc<uni_locy::ModelInvocationCache>>,
2809 provenance_store: Option<&Arc<uni_locy::NeuralProvenanceStore>>,
2810 path_context_handles: &HashMap<
2811 String,
2812 crate::query::df_graph::locy_model_invoke::PathContextHandle,
2813 >,
2814 xervo_runtime: &crate::query::df_graph::locy_model_invoke::XervoRuntimeHandle,
2815 graph_algo: &crate::query::df_graph::locy_model_invoke::GraphAlgoHandle,
2816) -> DFResult<Vec<RecordBatch>> {
2817 use uni_locy::ClassifyInput;
2818 if batches.is_empty() || invocations.is_empty() {
2819 return Ok(batches);
2820 }
2821 let semantic_match_embeddings =
2825 pre_embed_semantic_match_queries(invocations, xervo_runtime).await?;
2826 let graph_feature_maps = precompute_graph_feature_maps(invocations, graph_algo).await?;
2831 let neighbor_feature_maps =
2832 precompute_neighbor_feature_maps(invocations, &batches, graph_algo).await?;
2833 let mut out_batches = Vec::with_capacity(batches.len());
2834 for batch in batches {
2835 let mut current = batch;
2836 for invocation in invocations {
2837 let classifier = registry.get(&invocation.model_name).ok_or_else(|| {
2838 datafusion::error::DataFusionError::Execution(format!(
2839 "neural classifier '{}' not registered; \
2840 add it to LocyConfig::classifier_registry",
2841 invocation.model_name
2842 ))
2843 })?;
2844
2845 let resolvers = build_feature_resolvers(
2857 ¤t,
2858 invocation,
2859 path_context_handles,
2860 &semantic_match_embeddings,
2861 &graph_feature_maps,
2862 &neighbor_feature_maps,
2863 )?;
2864
2865 let n_rows = current.num_rows();
2867 let mut inputs: Vec<ClassifyInput> = Vec::with_capacity(n_rows);
2868 let mut input_hashes: Vec<u64> = Vec::with_capacity(n_rows);
2869 for row_idx in 0..n_rows {
2870 let mut features = std::collections::HashMap::new();
2871 for resolver in &resolvers {
2872 let value = resolver.eval_row(¤t, row_idx)?;
2873 features.insert(resolver.binding_name.clone(), value);
2874 }
2875 let input = ClassifyInput { features };
2876 input_hashes.push(input.stable_hash());
2877 inputs.push(input);
2878 }
2879
2880 let mut probs: Vec<f64> = vec![0.0; n_rows];
2884 let mut miss_inputs: Vec<ClassifyInput> = Vec::new();
2885 let mut miss_row_indices: Vec<usize> = Vec::new();
2886 if let Some(c) = cache {
2887 for (row_idx, h) in input_hashes.iter().enumerate() {
2888 match c.get(&invocation.model_name, *h) {
2889 Some(v) => probs[row_idx] = v,
2890 None => {
2891 miss_row_indices.push(row_idx);
2892 miss_inputs.push(inputs[row_idx].clone());
2893 }
2894 }
2895 }
2896 } else {
2897 miss_row_indices = (0..n_rows).collect();
2898 miss_inputs = inputs.clone();
2899 }
2900
2901 let calibrator = classifier.get_calibrator();
2910 let (miss_raws, miss_calibrated) = if miss_inputs.is_empty() {
2911 (Vec::new(), Vec::new())
2912 } else if calibrator.is_some() {
2913 let pairs = classifier
2914 .raw_and_calibrated(&miss_inputs)
2915 .await
2916 .map_err(|e| datafusion::error::DataFusionError::Execution(e.to_string()))?;
2917 if pairs.len() != miss_inputs.len() {
2918 return Err(datafusion::error::DataFusionError::Execution(format!(
2919 "classifier '{}' raw_and_calibrated returned {} outputs for {} inputs",
2920 invocation.model_name,
2921 pairs.len(),
2922 miss_inputs.len()
2923 )));
2924 }
2925 let raws: Vec<f64> = pairs.iter().map(|(r, _)| *r).collect();
2926 let cals: Vec<f64> = pairs.iter().map(|(r, c)| c.unwrap_or(*r)).collect();
2927 (raws, cals)
2928 } else {
2929 let r = classifier
2930 .classify(&miss_inputs)
2931 .await
2932 .map_err(|e| datafusion::error::DataFusionError::Execution(e.to_string()))?;
2933 if r.len() != miss_inputs.len() {
2934 return Err(datafusion::error::DataFusionError::Execution(format!(
2935 "classifier '{}' returned {} outputs for {} inputs",
2936 invocation.model_name,
2937 r.len(),
2938 miss_inputs.len()
2939 )));
2940 }
2941 (r.clone(), r)
2943 };
2944 let mut row_raw: Vec<Option<f64>> = vec![None; n_rows];
2954 for (i, &row_idx) in miss_row_indices.iter().enumerate() {
2955 probs[row_idx] = miss_calibrated[i];
2956 row_raw[row_idx] = Some(miss_raws[i]);
2957 if let Some(c) = cache {
2958 c.insert(
2959 &invocation.model_name,
2960 input_hashes[row_idx],
2961 miss_calibrated[i],
2962 );
2963 }
2964 }
2965
2966 if let Some(store) = provenance_store {
2974 for row_idx in 0..n_rows {
2975 let calibrated_value = probs[row_idx];
2976 let (raw_value, calibrated) = match (row_raw[row_idx], &calibrator) {
2977 (Some(raw), Some(_)) => (raw, Some(calibrated_value)),
2978 (Some(raw), None) => (raw, None),
2979 (None, _) => (
2984 calibrated_value,
2985 calibrator.as_ref().map(|_| calibrated_value),
2986 ),
2987 };
2988 let band = calibrator
2989 .as_ref()
2990 .and_then(|c| c.confidence_band(calibrated_value));
2991 store.record(
2992 &invocation.model_name,
2993 input_hashes[row_idx],
2994 uni_locy::NeuralProvenanceRecord {
2995 raw_probability: raw_value,
2996 calibrated_probability: calibrated,
2997 confidence_band: band,
2998 feature_inputs: inputs[row_idx].features.clone(),
3006 },
3007 );
3008 }
3009 }
3010
3011 let out_col: Arc<dyn arrow_array::Array> =
3016 Arc::new(arrow_array::Float64Array::from(probs));
3017 let schema = current.schema();
3018 let target_idx = schema.index_of(&invocation.output_column).ok();
3019 let mut columns: Vec<Arc<dyn arrow_array::Array>> = current.columns().to_vec();
3020 let mut fields: Vec<Arc<arrow_schema::Field>> =
3021 schema.fields().iter().cloned().collect();
3022 match target_idx {
3023 Some(idx) => {
3024 columns[idx] = out_col;
3025 fields[idx] = Arc::new(arrow_schema::Field::new(
3028 &invocation.output_column,
3029 arrow_schema::DataType::Float64,
3030 true,
3031 ));
3032 }
3033 None => {
3034 columns.push(out_col);
3035 fields.push(Arc::new(arrow_schema::Field::new(
3036 &invocation.output_column,
3037 arrow_schema::DataType::Float64,
3038 true,
3039 )));
3040 }
3041 }
3042 let new_schema = Arc::new(arrow_schema::Schema::new(fields));
3043 current = RecordBatch::try_new(new_schema, columns).map_err(arrow_err)?;
3044 }
3045 out_batches.push(current);
3046 }
3047 Ok(out_batches)
3048}
3049
/// Computes one named model feature for each row of a clause-body batch
/// (evaluated by `FeatureResolver::eval_row`).
struct FeatureResolver {
    /// Key under which the computed value is stored in `ClassifyInput::features`.
    binding_name: String,
    /// Strategy used to compute the value for a row.
    kind: FeatureResolverKind,
}
3061
/// How a `FeatureResolver` produces its per-row value.
enum FeatureResolverKind {
    /// Read the body-batch column at this index with `extract_feature_value`.
    Direct(usize),
    /// `similar_to(left, right)` evaluated with `eval_similar_to_pure`; a non-float result
    /// becomes `FeatureValue::Null`.
    SimilarTo {
        left: FeatureValueSrc,
        right: FeatureValueSrc,
    },
    /// Look up the subject vertex id in a map built from a path-context source rule.
    PathContext {
        /// Column holding the subject vertex id.
        subject_col: usize,
        /// Pre-built vid → feature value lookup.
        vid_to_value: Arc<HashMap<u64, uni_locy::FeatureValue>>,
    },
    /// Look up the subject vertex id in a pre-computed graph-algorithm score map.
    GraphAlgoScore {
        /// Column holding the subject vertex id.
        subject_col: usize,
        /// vid → score, computed once per algorithm for all batches.
        vid_to_score: Arc<HashMap<u64, f64>>,
    },
    /// Reduce the subject's pre-computed neighbor values with `op`.
    NeighborAggregate {
        /// Column holding the subject vertex id.
        subject_col: usize,
        /// Reduction applied to the neighbor values.
        op: NeighborAgg,
        /// vid → neighbor values for one (rel_type, property, direction) key.
        vid_to_values: Arc<HashMap<u64, Vec<f64>>>,
    },
}
3095
/// Reduction used by the `avg_neighbor`, `max_neighbor` and `sum_neighbor` FEATURE functions.
#[derive(Debug, Clone, Copy)]
enum NeighborAgg {
    /// Arithmetic mean of the values.
    Avg,
    /// Maximum of the values (via `f64::max`).
    Max,
    /// Sum of the values.
    Sum,
}
3102
/// Edge direction traversed by a neighbor-aggregator FEATURE.
///
/// Parsed from the optional 4th argument: `'OUTGOING'`, `'INCOMING'` or `'BOTH'`,
/// case-insensitive, defaulting to `Outgoing`. It is part of the `NeighborFeatureMaps` key,
/// hence `Eq + Hash`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
enum NeighborDirection {
    Outgoing,
    Incoming,
    Both,
}
3112
3113impl NeighborDirection {
3114 fn store_directions(self) -> &'static [uni_store::storage::direction::Direction] {
3115 use uni_store::storage::direction::Direction;
3116 match self {
3117 NeighborDirection::Outgoing => &[Direction::Outgoing],
3118 NeighborDirection::Incoming => &[Direction::Incoming],
3119 NeighborDirection::Both => &[Direction::Outgoing, Direction::Incoming],
3120 }
3121 }
3122}
3123
3124impl NeighborAgg {
3125 fn from_fn_name(name: &str) -> Option<Self> {
3126 match name {
3127 "avg_neighbor" => Some(NeighborAgg::Avg),
3128 "max_neighbor" => Some(NeighborAgg::Max),
3129 "sum_neighbor" => Some(NeighborAgg::Sum),
3130 _ => None,
3131 }
3132 }
3133
3134 fn apply(self, values: &[f64]) -> Option<f64> {
3135 if values.is_empty() {
3136 return None;
3137 }
3138 match self {
3139 NeighborAgg::Avg => Some(values.iter().sum::<f64>() / values.len() as f64),
3140 NeighborAgg::Max => values.iter().copied().reduce(f64::max),
3141 NeighborAgg::Sum => Some(values.iter().sum()),
3142 }
3143 }
3144}
3145
/// One operand of a `SimilarTo` feature: a body-batch column read per row, or a constant
/// (a literal, a list of literals, or a pre-computed embedding vector).
enum FeatureValueSrc {
    /// Index of a column in the body batch.
    Col(usize),
    /// Value used unchanged for every row.
    Const(uni_common::Value),
}
3152
3153impl FeatureValueSrc {
3154 fn resolve(&self, batch: &RecordBatch, row_idx: usize) -> uni_common::Value {
3155 match self {
3156 FeatureValueSrc::Col(idx) => extract_common_value(batch.column(*idx).as_ref(), row_idx),
3157 FeatureValueSrc::Const(v) => v.clone(),
3158 }
3159 }
3160}
3161
3162impl FeatureResolver {
3163 fn eval_row(&self, batch: &RecordBatch, row_idx: usize) -> DFResult<uni_locy::FeatureValue> {
3164 match &self.kind {
3165 FeatureResolverKind::Direct(idx) => {
3166 Ok(extract_feature_value(batch.column(*idx).as_ref(), row_idx))
3167 }
3168 FeatureResolverKind::SimilarTo { left, right } => {
3169 let lv = left.resolve(batch, row_idx);
3170 let rv = right.resolve(batch, row_idx);
3171 match crate::query::similar_to::eval_similar_to_pure(&lv, &rv) {
3172 Ok(uni_common::Value::Float(f)) => Ok(uni_locy::FeatureValue::Float(f)),
3173 Ok(_) => Ok(uni_locy::FeatureValue::Null),
3174 Err(e) => Err(datafusion::error::DataFusionError::Execution(format!(
3175 "similar_to UDF failed: {e}"
3176 ))),
3177 }
3178 }
3179 FeatureResolverKind::PathContext {
3180 subject_col,
3181 vid_to_value,
3182 } => {
3183 let col = batch.column(*subject_col);
3184 if col.is_null(row_idx) {
3185 return Ok(uni_locy::FeatureValue::Null);
3186 }
3187 if let Some(arr) = col.as_any().downcast_ref::<arrow_array::UInt64Array>() {
3188 let vid = arr.value(row_idx);
3189 Ok(vid_to_value
3190 .get(&vid)
3191 .cloned()
3192 .unwrap_or(uni_locy::FeatureValue::Null))
3193 } else if let Some(arr) = col.as_any().downcast_ref::<arrow_array::Int64Array>() {
3194 let vid = arr.value(row_idx) as u64;
3195 Ok(vid_to_value
3196 .get(&vid)
3197 .cloned()
3198 .unwrap_or(uni_locy::FeatureValue::Null))
3199 } else {
3200 Ok(uni_locy::FeatureValue::Null)
3201 }
3202 }
3203 FeatureResolverKind::GraphAlgoScore {
3204 subject_col,
3205 vid_to_score,
3206 } => {
3207 let col = batch.column(*subject_col);
3208 if col.is_null(row_idx) {
3209 return Ok(uni_locy::FeatureValue::Null);
3210 }
3211 let vid_opt: Option<u64> = if let Some(arr) =
3212 col.as_any().downcast_ref::<arrow_array::UInt64Array>()
3213 {
3214 Some(arr.value(row_idx))
3215 } else if let Some(arr) = col.as_any().downcast_ref::<arrow_array::Int64Array>() {
3216 Some(arr.value(row_idx) as u64)
3217 } else {
3218 match extract_common_value(col.as_ref(), row_idx) {
3224 uni_common::Value::Node(n) => Some(n.vid.as_u64()),
3225 uni_common::Value::Int(i) => Some(i as u64),
3226 _ => None,
3227 }
3228 };
3229 Ok(vid_opt
3230 .and_then(|v| vid_to_score.get(&v).copied())
3231 .map(uni_locy::FeatureValue::Float)
3232 .unwrap_or(uni_locy::FeatureValue::Null))
3233 }
3234 FeatureResolverKind::NeighborAggregate {
3235 subject_col,
3236 op,
3237 vid_to_values,
3238 } => {
3239 let vid_opt = extract_vid_from_column(batch.column(*subject_col).as_ref(), row_idx);
3240 Ok(vid_opt
3241 .and_then(|v| vid_to_values.get(&v))
3242 .and_then(|values| op.apply(values))
3243 .map(uni_locy::FeatureValue::Float)
3244 .unwrap_or(uni_locy::FeatureValue::Null))
3245 }
3246 }
3247 }
3248}
3249
3250fn extract_vid_from_column(col: &dyn arrow_array::Array, row_idx: usize) -> Option<u64> {
3255 if col.is_null(row_idx) {
3256 return None;
3257 }
3258 if let Some(arr) = col.as_any().downcast_ref::<arrow_array::UInt64Array>() {
3259 return Some(arr.value(row_idx));
3260 }
3261 if let Some(arr) = col.as_any().downcast_ref::<arrow_array::Int64Array>() {
3262 return Some(arr.value(row_idx) as u64);
3263 }
3264 match extract_common_value(col, row_idx) {
3265 uni_common::Value::Node(n) => Some(n.vid.as_u64()),
3266 uni_common::Value::Int(i) => Some(i as u64),
3267 _ => None,
3268 }
3269}
3270
3271#[allow(clippy::too_many_arguments)]
3272fn build_feature_resolvers(
3273 batch: &RecordBatch,
3274 invocation: &uni_locy::ModelInvocation,
3275 path_context_handles: &HashMap<
3276 String,
3277 crate::query::df_graph::locy_model_invoke::PathContextHandle,
3278 >,
3279 semantic_match_embeddings: &HashMap<String, Vec<f32>>,
3280 graph_feature_maps: &HashMap<String, Arc<HashMap<u64, f64>>>,
3281 neighbor_feature_maps: &NeighborFeatureMaps,
3282) -> DFResult<Vec<FeatureResolver>> {
3283 use uni_cypher::ast::Expr;
3284 let schema = batch.schema();
3285 let lookup_col = |name_or_property: String| -> DFResult<usize> {
3286 schema.index_of(&name_or_property).map_err(|_| {
3287 datafusion::error::DataFusionError::Execution(format!(
3288 "feature column '{name_or_property}' not found in clause body output schema"
3289 ))
3290 })
3291 };
3292 let resolve_src = |expr: &Expr| -> DFResult<FeatureValueSrc> {
3297 match expr {
3298 Expr::Variable(name) => {
3299 let col = if schema.index_of(name).is_ok() {
3300 name.clone()
3301 } else {
3302 let vid_name = format!("{}._vid", name);
3303 if schema.index_of(&vid_name).is_ok() {
3304 vid_name
3305 } else {
3306 name.clone()
3307 }
3308 };
3309 Ok(FeatureValueSrc::Col(lookup_col(col)?))
3310 }
3311 Expr::Property(boxed, prop) if matches!(boxed.as_ref(), Expr::Variable(_)) => {
3312 let Expr::Variable(v) = boxed.as_ref() else {
3313 unreachable!()
3314 };
3315 let direct = format!("{}.{}", v, prop);
3316 let col = if schema.index_of(&direct).is_ok() {
3317 direct
3318 } else {
3319 format!("__feat_{}_{}", v, prop)
3320 };
3321 Ok(FeatureValueSrc::Col(lookup_col(col)?))
3322 }
3323 Expr::Literal(lit) => Ok(FeatureValueSrc::Const(lit.to_value())),
3324 Expr::List(items) => {
3325 let mut out = Vec::with_capacity(items.len());
3326 for it in items {
3327 out.push(match it {
3328 Expr::Literal(lit) => lit.to_value(),
3329 _ => uni_common::Value::Null,
3330 });
3331 }
3332 Ok(FeatureValueSrc::Const(uni_common::Value::List(out)))
3333 }
3334 other => Err(datafusion::error::DataFusionError::Execution(format!(
3335 "unsupported feature sub-expression: {other:?}"
3336 ))),
3337 }
3338 };
3339
3340 if let Some(pc) = &invocation.path_context {
3348 let handle = path_context_handles.get(&pc.source_rule).ok_or_else(|| {
3349 datafusion::error::DataFusionError::Execution(format!(
3350 "model '{}' path_context references rule '{}' but no DerivedScanHandle \
3351 was registered; this should never happen — the build_clause path \
3352 mints a handle for every distinct source_rule in the invocation set",
3353 invocation.model_name, pc.source_rule
3354 ))
3355 })?;
3356 let subject_col = schema
3357 .index_of(&format!("{}._vid", pc.subject_var))
3358 .or_else(|_| schema.index_of(&pc.subject_var))
3359 .map_err(|_| {
3360 datafusion::error::DataFusionError::Execution(format!(
3361 "model '{}' path_context: subject column '{}' (or '{0}._vid') not \
3362 in body batch schema",
3363 invocation.model_name, pc.subject_var
3364 ))
3365 })?;
3366 let vid_to_value =
3367 build_path_context_lookup(handle, &pc.subject_var, &pc.column, &invocation.model_name)?;
3368 return Ok(vec![FeatureResolver {
3369 binding_name: pc.column.clone(),
3370 kind: FeatureResolverKind::PathContext {
3371 subject_col,
3372 vid_to_value: Arc::new(vid_to_value),
3373 },
3374 }]);
3375 }
3376
3377 let mut out = Vec::with_capacity(invocation.feature_exprs.len());
3378 for (i, fexpr) in invocation.feature_exprs.iter().enumerate() {
3379 let binding_name = invocation.feature_names[i].clone();
3380 let kind = match fexpr {
3381 Expr::FunctionCall { name, args, .. } if name == "similar_to" => {
3382 if args.len() != 2 {
3383 return Err(datafusion::error::DataFusionError::Execution(format!(
3384 "similar_to expects 2 args, got {}",
3385 args.len()
3386 )));
3387 }
3388 FeatureResolverKind::SimilarTo {
3389 left: resolve_src(&args[0])?,
3390 right: resolve_src(&args[1])?,
3391 }
3392 }
3393 Expr::FunctionCall { name, args, .. } if name == "semantic_match" => {
3394 if args.len() != 2 {
3399 return Err(datafusion::error::DataFusionError::Execution(format!(
3400 "semantic_match expects 2 args, got {}",
3401 args.len()
3402 )));
3403 }
3404 let text = match &args[1] {
3405 Expr::Literal(uni_cypher::ast::CypherLiteral::String(s)) => s.clone(),
3406 other => {
3407 return Err(datafusion::error::DataFusionError::Execution(format!(
3408 "semantic_match: 2nd arg must be a string literal, got {other:?}"
3409 )));
3410 }
3411 };
3412 let embedded = semantic_match_embeddings.get(&text).ok_or_else(|| {
3413 datafusion::error::DataFusionError::Execution(format!(
3414 "semantic_match: query text '{text}' was not pre-embedded. \
3415 This is a bug — `apply_model_invocations` should have \
3416 embedded all unique semantic_match texts up front. Most \
3417 likely the Xervo runtime is not configured (configure \
3418 via `LocyConfig::xervo_runtime` or its equivalent)."
3419 ))
3420 })?;
3421 let right_vec: Vec<f32> = embedded.clone();
3422 FeatureResolverKind::SimilarTo {
3423 left: resolve_src(&args[0])?,
3424 right: FeatureValueSrc::Const(uni_common::Value::Vector(right_vec)),
3425 }
3426 }
3427 Expr::FunctionCall { name, args, .. }
3428 if matches!(
3429 name.as_str(),
3430 "degree_centrality"
3431 | "pagerank_score"
3432 | "closeness_centrality"
3433 | "betweenness_centrality"
3434 | "eigenvector_centrality"
3435 | "harmonic_centrality"
3436 | "katz_centrality"
3437 ) =>
3438 {
3439 if args.len() != 1 {
3440 return Err(datafusion::error::DataFusionError::Execution(format!(
3441 "{name} expects 1 arg, got {}",
3442 args.len()
3443 )));
3444 }
3445 let Expr::Variable(v) = &args[0] else {
3446 return Err(datafusion::error::DataFusionError::Execution(format!(
3447 "{name}(...) argument must be a node variable, got {:?}",
3448 args[0]
3449 )));
3450 };
3451 let subject_col = {
3452 let direct = schema.index_of(v).ok();
3453 let vid_name = format!("{}._vid", v);
3454 let vid_col = schema.index_of(&vid_name).ok();
3455 vid_col.or(direct).ok_or_else(|| {
3456 datafusion::error::DataFusionError::Execution(format!(
3457 "{name}: subject column '{v}' (or '{v}._vid') not in body batch schema"
3458 ))
3459 })?
3460 };
3461 let vid_to_score = graph_feature_maps.get(name).cloned().ok_or_else(|| {
3462 datafusion::error::DataFusionError::Execution(format!(
3463 "{name}: pre-computed score map missing. This is a bug — \
3464 `apply_model_invocations` should have called \
3465 `precompute_graph_feature_maps` for every graph-structural \
3466 feature before building resolvers. Most likely the graph \
3467 algorithm registry is not configured."
3468 ))
3469 })?;
3470 FeatureResolverKind::GraphAlgoScore {
3471 subject_col,
3472 vid_to_score,
3473 }
3474 }
3475 Expr::FunctionCall { name, args, .. }
3476 if matches!(
3477 name.as_str(),
3478 "avg_neighbor" | "max_neighbor" | "sum_neighbor"
3479 ) =>
3480 {
3481 if args.len() != 3 && args.len() != 4 {
3482 return Err(datafusion::error::DataFusionError::Execution(format!(
3483 "{name} expects 3 or 4 args, got {}",
3484 args.len()
3485 )));
3486 }
3487 let Expr::Variable(v) = &args[0] else {
3488 return Err(datafusion::error::DataFusionError::Execution(format!(
3489 "{name}(...) first argument must be a node variable, got {:?}",
3490 args[0]
3491 )));
3492 };
3493 let rel_type = match &args[1] {
3494 Expr::Literal(uni_cypher::ast::CypherLiteral::String(s)) => s.clone(),
3495 other => {
3496 return Err(datafusion::error::DataFusionError::Execution(format!(
3497 "{name}: 2nd arg must be a string literal (rel-type), got {other:?}"
3498 )));
3499 }
3500 };
3501 let prop_name = match &args[2] {
3502 Expr::Literal(uni_cypher::ast::CypherLiteral::String(s)) => s.clone(),
3503 other => {
3504 return Err(datafusion::error::DataFusionError::Execution(format!(
3505 "{name}: 3rd arg must be a string literal (property), got {other:?}"
3506 )));
3507 }
3508 };
3509 let direction_arg = match args.get(3) {
3510 None => NeighborDirection::Outgoing,
3511 Some(Expr::Literal(uni_cypher::ast::CypherLiteral::String(d))) => {
3512 match d.to_uppercase().as_str() {
3513 "OUTGOING" => NeighborDirection::Outgoing,
3514 "INCOMING" => NeighborDirection::Incoming,
3515 "BOTH" => NeighborDirection::Both,
3516 other => {
3517 return Err(datafusion::error::DataFusionError::Execution(
3518 format!(
3519 "{name}: direction must be OUTGOING|INCOMING|BOTH, got '{other}'"
3520 ),
3521 ));
3522 }
3523 }
3524 }
3525 Some(other) => {
3526 return Err(datafusion::error::DataFusionError::Execution(format!(
3527 "{name}: 4th arg must be a string literal (direction), got {other:?}"
3528 )));
3529 }
3530 };
3531 let subject_col = {
3532 let direct = schema.index_of(v).ok();
3533 let vid_name = format!("{}._vid", v);
3534 let vid_col = schema.index_of(&vid_name).ok();
3535 vid_col.or(direct).ok_or_else(|| {
3536 datafusion::error::DataFusionError::Execution(format!(
3537 "{name}: subject column '{v}' (or '{v}._vid') not in body batch schema"
3538 ))
3539 })?
3540 };
3541 let vid_to_values = neighbor_feature_maps
3542 .get(&(rel_type.clone(), prop_name.clone(), direction_arg))
3543 .cloned()
3544 .ok_or_else(|| {
3545 datafusion::error::DataFusionError::Execution(format!(
3546 "{name}: pre-computed neighbor map missing for ({rel_type}, {prop_name}, {direction_arg:?}). \
3547 This is a bug — `apply_model_invocations` should have called \
3548 `precompute_neighbor_feature_maps` for every neighbor-aggregator \
3549 feature before building resolvers."
3550 ))
3551 })?;
3552 let op = NeighborAgg::from_fn_name(name).unwrap();
3553 FeatureResolverKind::NeighborAggregate {
3554 subject_col,
3555 op,
3556 vid_to_values,
3557 }
3558 }
3559 other => match resolve_src(other)? {
3560 FeatureValueSrc::Col(idx) => FeatureResolverKind::Direct(idx),
3561 FeatureValueSrc::Const(_) => {
3562 return Err(datafusion::error::DataFusionError::Execution(format!(
3563 "model '{}' feature must reference a variable or property — got a literal",
3564 invocation.model_name
3565 )));
3566 }
3567 },
3568 };
3569 out.push(FeatureResolver { binding_name, kind });
3570 }
3571 Ok(out)
3572}
3573
3574async fn pre_embed_semantic_match_queries(
3581 invocations: &[uni_locy::ModelInvocation],
3582 xervo_runtime: &crate::query::df_graph::locy_model_invoke::XervoRuntimeHandle,
3583) -> DFResult<HashMap<String, Vec<f32>>> {
3584 use uni_cypher::ast::{CypherLiteral, Expr};
3585 let mut needed: Vec<(String, String)> = Vec::new();
3595 for inv in invocations {
3596 let alias = inv
3597 .embedder_alias
3598 .clone()
3599 .unwrap_or_else(|| "default".to_string());
3600 for fexpr in &inv.feature_exprs {
3601 if let Expr::FunctionCall { name, args, .. } = fexpr
3602 && name == "semantic_match"
3603 && args.len() == 2
3604 && let Expr::Literal(CypherLiteral::String(s)) = &args[1]
3605 {
3606 let tuple = (s.clone(), alias.clone());
3607 if !needed.contains(&tuple) {
3608 needed.push(tuple);
3609 }
3610 }
3611 }
3612 }
3613 if needed.is_empty() {
3614 return Ok(HashMap::new());
3615 }
3616 let runtime = xervo_runtime.as_ref().ok_or_else(|| {
3617 datafusion::error::DataFusionError::Execution(
3618 "semantic_match: Uni-Xervo runtime not configured. Either provide \
3619 one via `LocyConfig::xervo_runtime` (or its equivalent setup \
3620 path) or pre-compute the query embedding and pass it via \
3621 `similar_to(prop, <literal_vector>)`."
3622 .to_string(),
3623 )
3624 })?;
3625 let mut by_alias: HashMap<String, Vec<String>> = HashMap::new();
3628 for (text, alias) in &needed {
3629 by_alias
3630 .entry(alias.clone())
3631 .or_default()
3632 .push(text.clone());
3633 }
3634 let mut out: HashMap<String, Vec<f32>> = HashMap::new();
3635 for (alias, texts) in by_alias {
3636 let embedder = runtime.embedding(&alias).await.map_err(|e| {
3637 datafusion::error::DataFusionError::Execution(format!(
3638 "semantic_match: failed to obtain embedder for alias '{alias}': {e}. \
3639 Register an embedder under that alias in your Uni-Xervo runtime, or \
3640 pre-compute the query embedding and pass via similar_to."
3641 ))
3642 })?;
3643 let text_refs: Vec<&str> = texts.iter().map(|s| s.as_str()).collect();
3644 let embeddings = embedder.embed(text_refs).await.map_err(|e| {
3645 datafusion::error::DataFusionError::Execution(format!(
3646 "semantic_match: embedder '{alias}' call failed: {e}"
3647 ))
3648 })?;
3649 if embeddings.len() != texts.len() {
3650 return Err(datafusion::error::DataFusionError::Execution(format!(
3651 "semantic_match: embedder '{alias}' returned {} vectors for {} queries",
3652 embeddings.len(),
3653 texts.len()
3654 )));
3655 }
3656 for (text, vec) in texts.into_iter().zip(embeddings) {
3657 out.insert(text, vec);
3658 }
3659 }
3660 Ok(out)
3661}
3662
3663async fn precompute_graph_feature_maps(
3676 invocations: &[uni_locy::ModelInvocation],
3677 graph_algo: &crate::query::df_graph::locy_model_invoke::GraphAlgoHandle,
3678) -> DFResult<HashMap<String, Arc<HashMap<u64, f64>>>> {
3679 use futures::StreamExt;
3680 use uni_algo::algo::procedures::AlgoContext;
3681 use uni_cypher::ast::Expr;
3682
3683 fn procedure_for(fn_name: &str) -> Option<&'static str> {
3686 match fn_name {
3687 "degree_centrality" => Some("uni.algo.degreeCentrality"),
3688 "pagerank_score" => Some("uni.algo.pageRank"),
3689 "closeness_centrality" => Some("uni.algo.closeness"),
3690 "betweenness_centrality" => Some("uni.algo.betweenness"),
3691 "eigenvector_centrality" => Some("uni.algo.eigenvectorCentrality"),
3692 "harmonic_centrality" => Some("uni.algo.harmonicCentrality"),
3693 "katz_centrality" => Some("uni.algo.katzCentrality"),
3694 _ => None,
3695 }
3696 }
3697
3698 let mut needed: Vec<String> = Vec::new();
3702 for inv in invocations {
3703 for fexpr in &inv.feature_exprs {
3704 if let Expr::FunctionCall { name, .. } = fexpr
3705 && procedure_for(name).is_some()
3706 && !needed.contains(name)
3707 {
3708 needed.push(name.clone());
3709 }
3710 }
3711 }
3712 if needed.is_empty() {
3713 return Ok(HashMap::new());
3714 }
3715
3716 let registry = graph_algo.registry.as_ref().ok_or_else(|| {
3717 datafusion::error::DataFusionError::Execution(
3718 "graph-structural FEATURE invoked but no `AlgorithmRegistry` is \
3719 configured. Configure one on `GraphExecutionContext::with_algo_registry`."
3720 .to_string(),
3721 )
3722 })?;
3723 let storage = graph_algo.storage.as_ref().ok_or_else(|| {
3724 datafusion::error::DataFusionError::Execution(
3725 "graph-structural FEATURE invoked but no storage handle was \
3726 threaded into the FEATURE runtime. This is a bug in df_planner."
3727 .to_string(),
3728 )
3729 })?;
3730
3731 let mut out: HashMap<String, Arc<HashMap<u64, f64>>> = HashMap::new();
3732 for fn_name in needed {
3733 let proc_name = procedure_for(&fn_name).unwrap();
3734 let procedure = registry.get(proc_name).ok_or_else(|| {
3735 datafusion::error::DataFusionError::Execution(format!(
3736 "graph-structural FEATURE '{fn_name}' resolves to procedure \
3737 '{proc_name}' which is not in the algorithm registry"
3738 ))
3739 })?;
3740 let args: Vec<serde_json::Value> = vec![
3745 serde_json::Value::Array(Vec::new()),
3746 serde_json::Value::Array(Vec::new()),
3747 ];
3748 let algo_ctx = AlgoContext::new(
3749 storage.clone(),
3750 graph_algo.l0_manager.as_ref().map(Arc::clone),
3751 );
3752 let filled_args = procedure
3773 .signature()
3774 .validate_args(args.clone())
3775 .map_err(|e| {
3776 datafusion::error::DataFusionError::Execution(format!(
3777 "graph-structural FEATURE '{fn_name}': argument validation failed: {e}"
3778 ))
3779 })?;
3780 let projection = uni_algo::algo::procedure_template::build_projection_from_direct_args(
3781 procedure.as_ref(),
3782 &algo_ctx,
3783 &filled_args,
3784 )
3785 .await
3786 .map_err(|e| {
3787 datafusion::error::DataFusionError::Execution(format!(
3788 "graph-structural FEATURE '{fn_name}': projection build failed: {e}"
3789 ))
3790 })?;
3791 let mut stream = procedure.execute_with_projection(algo_ctx, args, projection);
3792 let mut score_map: HashMap<u64, f64> = HashMap::new();
3793 let sig = procedure.signature();
3794 let node_idx = sig
3795 .yields
3796 .iter()
3797 .position(|(n, _)| *n == "nodeId")
3798 .ok_or_else(|| {
3799 datafusion::error::DataFusionError::Execution(format!(
3800 "procedure '{proc_name}' yield schema missing 'nodeId'"
3801 ))
3802 })?;
3803 let score_idx = sig
3808 .yields
3809 .iter()
3810 .position(|(n, _)| *n == "score" || *n == "centrality")
3811 .ok_or_else(|| {
3812 datafusion::error::DataFusionError::Execution(format!(
3813 "procedure '{proc_name}' yield schema missing a numeric score column \
3814 (expected 'score' or 'centrality')"
3815 ))
3816 })?;
3817 while let Some(row_res) = stream.next().await {
3818 let row = row_res.map_err(|e| {
3819 datafusion::error::DataFusionError::Execution(format!(
3820 "graph-structural FEATURE '{fn_name}': procedure '{proc_name}' failed: {e}"
3821 ))
3822 })?;
3823 let vid_v = row.values.get(node_idx);
3824 let score_v = row.values.get(score_idx);
3825 let (Some(vid_v), Some(score_v)) = (vid_v, score_v) else {
3826 continue;
3827 };
3828 let vid = vid_v.as_u64().or_else(|| vid_v.as_i64().map(|i| i as u64));
3829 let score = score_v
3830 .as_f64()
3831 .or_else(|| score_v.as_i64().map(|i| i as f64));
3832 if let (Some(vid), Some(score)) = (vid, score) {
3833 score_map.insert(vid, score);
3834 }
3835 }
3836 out.insert(fn_name, Arc::new(score_map));
3837 }
3838 Ok(out)
3839}
3840
/// Pre-computed inputs for neighbor-aggregator FEATUREs, keyed by
/// `(rel_type, property, direction)` as looked up in `build_feature_resolvers`.
///
/// Each value maps a subject vertex id to the `Vec<f64>` that `NeighborAgg::apply` reduces
/// per row. Presumably these are the `property` values of the subject's neighbors over
/// `rel_type`, as built by `precompute_neighbor_feature_maps` — TODO confirm.
type NeighborFeatureMaps =
    HashMap<(String, String, NeighborDirection), Arc<HashMap<u64, Vec<f64>>>>;
3868
/// Pre-computes neighbor property values for every neighbor-aggregator FEATURE
/// expression used by `invocations`.
///
/// A feature expression is handled here when it is a function call whose name
/// `NeighborAgg::from_fn_name` recognises, with arguments
/// `(subject_variable, 'REL_TYPE', 'property'[, 'OUTGOING' | 'INCOMING' | 'BOTH'])`.
/// Calls with any other argument shape, or an unrecognised direction literal, are
/// skipped by this function.
///
/// The returned map is keyed by `(rel_type, property, direction)`. Each value maps a
/// subject VID (read from the `<var>._vid` column of `batches`, or a bare `<var>`
/// column) to the non-NaN `f64` values of `property` found on that subject's neighbors.
/// Edge types that are not in the schema map to an empty inner map.
///
/// # Errors
/// Returns `DataFusionError::Execution` when:
/// - the FEATURE runtime was given no storage handle or no `PropertyManager`;
/// - warming the adjacency cache fails;
/// - a neighbor property read fails.
async fn precompute_neighbor_feature_maps(
    invocations: &[uni_locy::ModelInvocation],
    batches: &[RecordBatch],
    graph_algo: &crate::query::df_graph::locy_model_invoke::GraphAlgoHandle,
) -> DFResult<NeighborFeatureMaps> {
    use uni_cypher::ast::{CypherLiteral, Expr};

    // Parses the optional 4th argument (case-insensitive string literal).
    // Missing => Outgoing. A non-literal or unknown value => None, which makes the
    // `let Some(direction)` condition below fail, so the call is skipped.
    let parse_direction = |arg: Option<&Expr>| -> Option<NeighborDirection> {
        match arg {
            None => Some(NeighborDirection::Outgoing),
            Some(Expr::Literal(CypherLiteral::String(d))) => match d.to_uppercase().as_str() {
                "OUTGOING" => Some(NeighborDirection::Outgoing),
                "INCOMING" => Some(NeighborDirection::Incoming),
                "BOTH" => Some(NeighborDirection::Both),
                _ => None,
            },
            _ => None,
        }
    };
    // Distinct (subject_var, rel_type, property, direction) requests, in first-seen order.
    let mut needed: Vec<(String, String, String, NeighborDirection)> = Vec::new();
    for inv in invocations {
        for fexpr in &inv.feature_exprs {
            if let Expr::FunctionCall { name, args, .. } = fexpr
                && NeighborAgg::from_fn_name(name).is_some()
                && (args.len() == 3 || args.len() == 4)
                && let Expr::Variable(v) = &args[0]
                && let Expr::Literal(CypherLiteral::String(rel)) = &args[1]
                && let Expr::Literal(CypherLiteral::String(prop)) = &args[2]
                && let Some(direction) = parse_direction(args.get(3))
            {
                let tuple = (v.clone(), rel.clone(), prop.clone(), direction);
                if !needed.contains(&tuple) {
                    needed.push(tuple);
                }
            }
        }
    }
    // No neighbor aggregators: skip storage access entirely.
    if needed.is_empty() {
        return Ok(HashMap::new());
    }

    let storage = graph_algo.storage.as_ref().ok_or_else(|| {
        datafusion::error::DataFusionError::Execution(
            "neighbor-aggregator FEATURE invoked but no storage handle was \
             threaded into the FEATURE runtime. This is a bug in df_planner."
                .to_string(),
        )
    })?;
    let property_manager = graph_algo.property_manager.as_ref().ok_or_else(|| {
        datafusion::error::DataFusionError::Execution(
            "neighbor-aggregator FEATURE invoked but no PropertyManager was \
             threaded into the FEATURE runtime. This is a bug in df_planner."
                .to_string(),
        )
    })?;
    // When L0 buffers are available, property reads go through a QueryContext built
    // from them. NOTE(review): presumably this makes not-yet-flushed writes visible;
    // confirm.
    let query_ctx = graph_algo.l0_buffers.as_ref().map(|bufs| {
        uni_store::runtime::context::QueryContext::new_with_pending(
            bufs.current.clone(),
            bufs.transaction.clone(),
            bufs.pending_flush.clone(),
        )
    });

    // Group subject variables so each (rel_type, property, direction) is scanned once.
    let mut by_key: HashMap<(String, String, NeighborDirection), Vec<String>> = HashMap::new();
    for (subject_var, rel, prop, direction) in needed {
        by_key
            .entry((rel, prop, direction))
            .or_default()
            .push(subject_var);
    }

    let mut out: NeighborFeatureMaps = HashMap::new();
    for ((rel_type, prop_name, direction), subject_vars) in by_key {
        let schema = storage.schema_manager().schema();
        // Unknown edge type: record an empty map (no subject has neighbors).
        let Some(edge_meta) = schema.edge_types.get(&rel_type) else {
            out.insert((rel_type, prop_name, direction), Arc::new(HashMap::new()));
            continue;
        };
        let edge_type_id = edge_meta.id;

        // Warm adjacency for every stored direction before the `get_neighbors` calls below.
        let edge_ver = storage.get_edge_version_by_id(edge_type_id);
        for dir in direction.store_directions() {
            storage
                .warm_adjacency(edge_type_id, *dir, edge_ver)
                .await
                .map_err(|e| {
                    datafusion::error::DataFusionError::Execution(format!(
                        "neighbor-aggregator warm_adjacency for '{rel_type}' / {dir:?} failed: {e}"
                    ))
                })?;
        }

        // Collect subject VIDs from `<var>._vid`, falling back to a bare `<var>` column.
        // Batches with neither column are ignored.
        let mut subject_vids: std::collections::HashSet<u64> = std::collections::HashSet::new();
        for subject_var in &subject_vars {
            for batch in batches {
                let schema = batch.schema();
                let col_idx = schema
                    .index_of(&format!("{}._vid", subject_var))
                    .ok()
                    .or_else(|| schema.index_of(subject_var).ok());
                let Some(col_idx) = col_idx else { continue };
                let col = batch.column(col_idx);
                for row in 0..batch.num_rows() {
                    if let Some(v) = extract_vid_from_column(col.as_ref(), row) {
                        subject_vids.insert(v);
                    }
                }
            }
        }

        // For each subject, read `prop_name` on every neighbor (across all stored
        // directions). Keep values that convert via `as_f64` and are not NaN.
        // Subjects without usable values still get an (empty) entry.
        let mut vid_to_values: HashMap<u64, Vec<f64>> = HashMap::new();
        let adj = storage.adjacency_manager();
        for subject_vid in subject_vids {
            let mut neighbors: Vec<(uni_common::core::id::Vid, uni_common::core::id::Eid)> =
                Vec::new();
            for dir in direction.store_directions() {
                neighbors.extend(adj.get_neighbors(
                    uni_common::core::id::Vid::from(subject_vid),
                    edge_type_id,
                    *dir,
                ));
            }
            let mut values: Vec<f64> = Vec::with_capacity(neighbors.len());
            for (neighbor_vid, _eid) in neighbors {
                let val = property_manager
                    .get_vertex_prop_with_ctx(neighbor_vid, &prop_name, query_ctx.as_ref())
                    .await
                    .map_err(|e| {
                        datafusion::error::DataFusionError::Execution(format!(
                            "neighbor-aggregator: failed to read property \
                             '{prop_name}' on neighbor vid {neighbor_vid:?}: {e}"
                        ))
                    })?;
                if let Some(f) = val.as_f64()
                    && !f.is_nan()
                {
                    values.push(f);
                }
            }
            vid_to_values.insert(subject_vid, values);
        }
        out.insert((rel_type, prop_name, direction), Arc::new(vid_to_values));
    }
    Ok(out)
}
4037
4038fn build_path_context_lookup(
4044 handle: &crate::query::df_graph::locy_model_invoke::PathContextHandle,
4045 _subject_var: &str,
4046 column: &str,
4047 model_name: &str,
4048) -> DFResult<HashMap<u64, uni_locy::FeatureValue>> {
4049 if handle.schema.fields().is_empty() {
4054 return Err(datafusion::error::DataFusionError::Execution(format!(
4055 "model '{model_name}' path_context: source rule has empty yield schema"
4056 )));
4057 }
4058 let subj_idx = 0_usize;
4059 let col_idx = handle.schema.index_of(column).map_err(|_| {
4060 datafusion::error::DataFusionError::Execution(format!(
4061 "model '{model_name}' path_context: column '{column}' not in \
4062 source rule's yield schema (have: {:?})",
4063 handle
4064 .schema
4065 .fields()
4066 .iter()
4067 .map(|f| f.name().clone())
4068 .collect::<Vec<_>>()
4069 ))
4070 })?;
4071 let batches = handle.data.read();
4072 let mut out: HashMap<u64, uni_locy::FeatureValue> = HashMap::new();
4073 for batch in batches.iter() {
4074 let subj_col = batch.column(subj_idx);
4075 let value_col = batch.column(col_idx);
4076 for row in 0..batch.num_rows() {
4077 if subj_col.is_null(row) {
4078 continue;
4079 }
4080 let vid = if let Some(a) = subj_col.as_any().downcast_ref::<arrow_array::UInt64Array>()
4081 {
4082 a.value(row)
4083 } else if let Some(a) = subj_col.as_any().downcast_ref::<arrow_array::Int64Array>() {
4084 a.value(row) as u64
4085 } else {
4086 continue;
4087 };
4088 let v = extract_feature_value(value_col.as_ref(), row);
4089 out.insert(vid, v);
4092 }
4093 }
4094 Ok(out)
4095}
4096
4097fn extract_common_value(col: &dyn arrow_array::Array, row_idx: usize) -> uni_common::Value {
4102 use arrow_array::{BooleanArray, Float64Array, Int64Array, LargeStringArray, StringArray};
4103 if col.is_null(row_idx) {
4104 return uni_common::Value::Null;
4105 }
4106 if let Some(a) = col.as_any().downcast_ref::<Float64Array>() {
4107 return uni_common::Value::Float(a.value(row_idx));
4108 }
4109 if let Some(a) = col.as_any().downcast_ref::<Int64Array>() {
4110 return uni_common::Value::Int(a.value(row_idx));
4111 }
4112 if let Some(a) = col.as_any().downcast_ref::<BooleanArray>() {
4113 return uni_common::Value::Bool(a.value(row_idx));
4114 }
4115 if let Some(a) = col.as_any().downcast_ref::<StringArray>() {
4116 return uni_common::Value::String(a.value(row_idx).to_string());
4117 }
4118 if let Some(a) = col.as_any().downcast_ref::<LargeStringArray>() {
4119 return uni_common::Value::String(a.value(row_idx).to_string());
4120 }
4121 if let Some(b) = col.as_any().downcast_ref::<arrow_array::LargeBinaryArray>() {
4122 let bytes = b.value(row_idx);
4123 if bytes.is_empty() {
4124 return uni_common::Value::Null;
4125 }
4126 return uni_common::cypher_value_codec::decode(bytes).unwrap_or(uni_common::Value::Null);
4127 }
4128 uni_common::Value::Null
4129}
4130
4131fn extract_feature_value(col: &dyn arrow_array::Array, row_idx: usize) -> uni_locy::FeatureValue {
4132 use arrow_array::{BooleanArray, Float64Array, Int64Array, LargeStringArray, StringArray};
4133 if col.is_null(row_idx) {
4134 return uni_locy::FeatureValue::Null;
4135 }
4136 if let Some(a) = col.as_any().downcast_ref::<Float64Array>() {
4137 return uni_locy::FeatureValue::Float(a.value(row_idx));
4138 }
4139 if let Some(a) = col.as_any().downcast_ref::<Int64Array>() {
4140 return uni_locy::FeatureValue::Int(a.value(row_idx));
4141 }
4142 if let Some(a) = col.as_any().downcast_ref::<BooleanArray>() {
4143 return uni_locy::FeatureValue::Bool(a.value(row_idx));
4144 }
4145 if let Some(a) = col.as_any().downcast_ref::<StringArray>() {
4146 return uni_locy::FeatureValue::String(a.value(row_idx).to_string());
4147 }
4148 if let Some(a) = col.as_any().downcast_ref::<LargeStringArray>() {
4149 return uni_locy::FeatureValue::String(a.value(row_idx).to_string());
4150 }
4151 if let Some(b) = col.as_any().downcast_ref::<arrow_array::LargeBinaryArray>() {
4155 let bytes = b.value(row_idx);
4156 if bytes.is_empty() {
4157 return uni_locy::FeatureValue::Null;
4158 }
4159 let v = uni_common::cypher_value_codec::decode(bytes).unwrap_or(uni_common::Value::Null);
4160 return match v {
4161 uni_common::Value::Float(f) => uni_locy::FeatureValue::Float(f),
4162 uni_common::Value::Int(i) => uni_locy::FeatureValue::Int(i),
4163 uni_common::Value::Bool(b) => uni_locy::FeatureValue::Bool(b),
4164 uni_common::Value::String(s) => uni_locy::FeatureValue::String(s),
4165 uni_common::Value::Null => uni_locy::FeatureValue::Null,
4166 _ => uni_locy::FeatureValue::Null,
4167 };
4168 }
4169 uni_locy::FeatureValue::Null
4170}
4171
4172pub fn apply_prob_complement(
4179 batches: Vec<RecordBatch>,
4180 neg_facts: &[RecordBatch],
4181 left_col: &str,
4182 right_col: &str,
4183 prob_col: &str,
4184 complement_col_name: &str,
4185) -> datafusion::error::Result<Vec<RecordBatch>> {
4186 use arrow_array::{Array as _, Float64Array, UInt64Array};
4187
4188 let mut prob_map: std::collections::HashMap<u64, f64> = std::collections::HashMap::new();
4190 for batch in neg_facts {
4191 let Ok(vid_idx) = batch.schema().index_of(right_col) else {
4192 continue;
4193 };
4194 let Ok(prob_idx) = batch.schema().index_of(prob_col) else {
4195 continue;
4196 };
4197 let Some(vids) = batch.column(vid_idx).as_any().downcast_ref::<UInt64Array>() else {
4198 continue;
4199 };
4200 let prob_arr = batch.column(prob_idx);
4201 let probs = prob_arr.as_any().downcast_ref::<Float64Array>();
4202 for i in 0..vids.len() {
4203 if !vids.is_null(i) {
4204 let p = probs
4205 .and_then(|arr| {
4206 if arr.is_null(i) {
4207 None
4208 } else {
4209 Some(arr.value(i))
4210 }
4211 })
4212 .unwrap_or(0.0);
4213 prob_map
4216 .entry(vids.value(i))
4217 .and_modify(|existing| {
4218 *existing = 1.0 - (1.0 - *existing) * (1.0 - p);
4219 })
4220 .or_insert(p);
4221 }
4222 }
4223 }
4224
4225 let mut result = Vec::new();
4227 for batch in batches {
4228 let Ok(idx) = batch.schema().index_of(left_col) else {
4229 result.push(batch);
4230 continue;
4231 };
4232 let Some(vids) = batch.column(idx).as_any().downcast_ref::<UInt64Array>() else {
4233 result.push(batch);
4234 continue;
4235 };
4236
4237 let complements: Vec<f64> = (0..vids.len())
4239 .map(|i| {
4240 if vids.is_null(i) {
4241 1.0
4242 } else {
4243 let p = prob_map.get(&vids.value(i)).copied().unwrap_or(0.0);
4244 1.0 - p
4245 }
4246 })
4247 .collect();
4248
4249 let complement_arr = Float64Array::from(complements);
4250
4251 let mut columns: Vec<arrow_array::ArrayRef> = batch.columns().to_vec();
4253 columns.push(std::sync::Arc::new(complement_arr));
4254
4255 let mut fields: Vec<std::sync::Arc<arrow_schema::Field>> =
4256 batch.schema().fields().iter().cloned().collect();
4257 fields.push(std::sync::Arc::new(arrow_schema::Field::new(
4258 complement_col_name,
4259 arrow_schema::DataType::Float64,
4260 true,
4261 )));
4262
4263 let new_schema = std::sync::Arc::new(arrow_schema::Schema::new(fields));
4264 let new_batch = RecordBatch::try_new(new_schema, columns).map_err(arrow_err)?;
4265 result.push(new_batch);
4266 }
4267 Ok(result)
4268}
4269
4270pub fn apply_prob_complement_composite(
4277 batches: Vec<RecordBatch>,
4278 neg_facts: &[RecordBatch],
4279 join_cols: &[(String, String)],
4280 prob_col: &str,
4281 complement_col_name: &str,
4282) -> datafusion::error::Result<Vec<RecordBatch>> {
4283 use arrow_array::{Array as _, Float64Array, UInt64Array};
4284
4285 let mut prob_map: HashMap<Vec<u64>, f64> = HashMap::new();
4287 for batch in neg_facts {
4288 let right_indices: Vec<usize> = join_cols
4289 .iter()
4290 .filter_map(|(_, rc)| batch.schema().index_of(rc).ok())
4291 .collect();
4292 if right_indices.len() != join_cols.len() {
4293 continue;
4294 }
4295 let Ok(prob_idx) = batch.schema().index_of(prob_col) else {
4296 continue;
4297 };
4298 let prob_arr = batch.column(prob_idx);
4299 let probs = prob_arr.as_any().downcast_ref::<Float64Array>();
4300 for row in 0..batch.num_rows() {
4301 let mut key = Vec::with_capacity(right_indices.len());
4302 let mut valid = true;
4303 for &ci in &right_indices {
4304 let col = batch.column(ci);
4305 if let Some(vids) = col.as_any().downcast_ref::<UInt64Array>() {
4306 if vids.is_null(row) {
4307 valid = false;
4308 break;
4309 }
4310 key.push(vids.value(row));
4311 } else {
4312 valid = false;
4313 break;
4314 }
4315 }
4316 if !valid {
4317 continue;
4318 }
4319 let p = probs
4320 .and_then(|arr| {
4321 if arr.is_null(row) {
4322 None
4323 } else {
4324 Some(arr.value(row))
4325 }
4326 })
4327 .unwrap_or(0.0);
4328 prob_map
4330 .entry(key)
4331 .and_modify(|existing| {
4332 *existing = 1.0 - (1.0 - *existing) * (1.0 - p);
4333 })
4334 .or_insert(p);
4335 }
4336 }
4337
4338 let mut result = Vec::new();
4340 for batch in batches {
4341 let left_indices: Vec<usize> = join_cols
4342 .iter()
4343 .filter_map(|(lc, _)| batch.schema().index_of(lc).ok())
4344 .collect();
4345 if left_indices.len() != join_cols.len() {
4346 result.push(batch);
4347 continue;
4348 }
4349 let all_u64 = left_indices.iter().all(|&ci| {
4350 batch
4351 .column(ci)
4352 .as_any()
4353 .downcast_ref::<UInt64Array>()
4354 .is_some()
4355 });
4356 if !all_u64 {
4357 result.push(batch);
4358 continue;
4359 }
4360
4361 let complements: Vec<f64> = (0..batch.num_rows())
4362 .map(|row| {
4363 let mut key = Vec::with_capacity(left_indices.len());
4364 for &ci in &left_indices {
4365 let vids = batch
4366 .column(ci)
4367 .as_any()
4368 .downcast_ref::<UInt64Array>()
4369 .unwrap();
4370 if vids.is_null(row) {
4371 return 1.0;
4372 }
4373 key.push(vids.value(row));
4374 }
4375 let p = prob_map.get(&key).copied().unwrap_or(0.0);
4376 1.0 - p
4377 })
4378 .collect();
4379
4380 let complement_arr = Float64Array::from(complements);
4381 let mut columns: Vec<arrow_array::ArrayRef> = batch.columns().to_vec();
4382 columns.push(Arc::new(complement_arr));
4383
4384 let mut fields: Vec<Arc<arrow_schema::Field>> =
4385 batch.schema().fields().iter().cloned().collect();
4386 fields.push(Arc::new(arrow_schema::Field::new(
4387 complement_col_name,
4388 arrow_schema::DataType::Float64,
4389 true,
4390 )));
4391
4392 let new_schema = Arc::new(arrow_schema::Schema::new(fields));
4393 let new_batch = RecordBatch::try_new(new_schema, columns).map_err(arrow_err)?;
4394 result.push(new_batch);
4395 }
4396 Ok(result)
4397}
4398
4399pub fn apply_anti_join_composite(
4405 batches: Vec<RecordBatch>,
4406 neg_facts: &[RecordBatch],
4407 join_cols: &[(String, String)],
4408) -> datafusion::error::Result<Vec<RecordBatch>> {
4409 use arrow::compute::filter_record_batch;
4410 use arrow_array::{Array as _, BooleanArray, UInt64Array};
4411
4412 let mut banned: HashSet<Vec<u64>> = HashSet::new();
4414 for batch in neg_facts {
4415 let right_indices: Vec<usize> = join_cols
4416 .iter()
4417 .filter_map(|(_, rc)| batch.schema().index_of(rc).ok())
4418 .collect();
4419 if right_indices.len() != join_cols.len() {
4420 continue;
4421 }
4422 for row in 0..batch.num_rows() {
4423 let mut key = Vec::with_capacity(right_indices.len());
4424 let mut valid = true;
4425 for &ci in &right_indices {
4426 let col = batch.column(ci);
4427 if let Some(vids) = col.as_any().downcast_ref::<UInt64Array>() {
4428 if vids.is_null(row) {
4429 valid = false;
4430 break;
4431 }
4432 key.push(vids.value(row));
4433 } else {
4434 valid = false;
4435 break;
4436 }
4437 }
4438 if valid {
4439 banned.insert(key);
4440 }
4441 }
4442 }
4443
4444 if banned.is_empty() {
4445 return Ok(batches);
4446 }
4447
4448 let mut result = Vec::new();
4450 for batch in batches {
4451 let left_indices: Vec<usize> = join_cols
4452 .iter()
4453 .filter_map(|(lc, _)| batch.schema().index_of(lc).ok())
4454 .collect();
4455 if left_indices.len() != join_cols.len() {
4456 result.push(batch);
4457 continue;
4458 }
4459 let all_u64 = left_indices.iter().all(|&ci| {
4460 batch
4461 .column(ci)
4462 .as_any()
4463 .downcast_ref::<UInt64Array>()
4464 .is_some()
4465 });
4466 if !all_u64 {
4467 result.push(batch);
4468 continue;
4469 }
4470
4471 let keep: Vec<bool> = (0..batch.num_rows())
4472 .map(|row| {
4473 let mut key = Vec::with_capacity(left_indices.len());
4474 for &ci in &left_indices {
4475 let vids = batch
4476 .column(ci)
4477 .as_any()
4478 .downcast_ref::<UInt64Array>()
4479 .unwrap();
4480 if vids.is_null(row) {
4481 return true; }
4483 key.push(vids.value(row));
4484 }
4485 !banned.contains(&key)
4486 })
4487 .collect();
4488 let keep_arr = BooleanArray::from(keep);
4489 let filtered = filter_record_batch(&batch, &keep_arr).map_err(arrow_err)?;
4490 if filtered.num_rows() > 0 {
4491 result.push(filtered);
4492 }
4493 }
4494 Ok(result)
4495}
4496
4497pub fn multiply_prob_factors(
4508 batches: Vec<RecordBatch>,
4509 prob_col: Option<&str>,
4510 complement_cols: &[String],
4511) -> datafusion::error::Result<Vec<RecordBatch>> {
4512 use arrow_array::{Array as _, Float64Array};
4513
4514 let mut result = Vec::with_capacity(batches.len());
4515
4516 for batch in batches {
4517 if batch.num_rows() == 0 {
4518 let keep: Vec<usize> = batch
4520 .schema()
4521 .fields()
4522 .iter()
4523 .enumerate()
4524 .filter(|(_, f)| !complement_cols.contains(f.name()))
4525 .map(|(i, _)| i)
4526 .collect();
4527 let fields: Vec<_> = keep
4528 .iter()
4529 .map(|&i| batch.schema().field(i).clone())
4530 .collect();
4531 let cols: Vec<_> = keep.iter().map(|&i| batch.column(i).clone()).collect();
4532 let schema = std::sync::Arc::new(arrow_schema::Schema::new(fields));
4533 result.push(
4534 RecordBatch::try_new(schema, cols).map_err(|e| {
4535 datafusion::error::DataFusionError::ArrowError(Box::new(e), None)
4536 })?,
4537 );
4538 continue;
4539 }
4540
4541 let num_rows = batch.num_rows();
4542
4543 let mut combined = vec![1.0f64; num_rows];
4545 for col_name in complement_cols {
4546 if let Ok(idx) = batch.schema().index_of(col_name) {
4547 let arr = batch
4548 .column(idx)
4549 .as_any()
4550 .downcast_ref::<Float64Array>()
4551 .ok_or_else(|| {
4552 datafusion::error::DataFusionError::Internal(format!(
4553 "Expected Float64 for complement column {col_name}"
4554 ))
4555 })?;
4556 for (i, val) in combined.iter_mut().enumerate().take(num_rows) {
4557 if !arr.is_null(i) {
4558 *val *= arr.value(i);
4559 }
4560 }
4561 }
4562 }
4563
4564 let final_prob: Vec<f64> = if let Some(prob_name) = prob_col {
4566 if let Ok(idx) = batch.schema().index_of(prob_name) {
4567 let arr = batch
4568 .column(idx)
4569 .as_any()
4570 .downcast_ref::<Float64Array>()
4571 .ok_or_else(|| {
4572 datafusion::error::DataFusionError::Internal(format!(
4573 "Expected Float64 for PROB column {prob_name}"
4574 ))
4575 })?;
4576 (0..num_rows)
4577 .map(|i| {
4578 if arr.is_null(i) {
4579 combined[i]
4580 } else {
4581 arr.value(i) * combined[i]
4582 }
4583 })
4584 .collect()
4585 } else {
4586 combined
4587 }
4588 } else {
4589 combined
4590 };
4591
4592 let new_prob_array: arrow_array::ArrayRef =
4593 std::sync::Arc::new(Float64Array::from(final_prob));
4594
4595 let mut fields = Vec::new();
4597 let mut columns = Vec::new();
4598
4599 for (idx, field) in batch.schema().fields().iter().enumerate() {
4600 if complement_cols.contains(field.name()) {
4601 continue;
4602 }
4603 if prob_col.is_some_and(|p| field.name() == p) {
4604 fields.push(field.clone());
4605 columns.push(new_prob_array.clone());
4606 } else {
4607 fields.push(field.clone());
4608 columns.push(batch.column(idx).clone());
4609 }
4610 }
4611
4612 let schema = std::sync::Arc::new(arrow_schema::Schema::new(fields));
4613 result.push(RecordBatch::try_new(schema, columns).map_err(arrow_err)?);
4614 }
4615
4616 Ok(result)
4617}
4618
4619fn update_derived_scan_handles(
4624 registry: &DerivedScanRegistry,
4625 states: &[FixpointState],
4626 current_rule_idx: usize,
4627 rules: &[FixpointRulePlan],
4628) {
4629 let current_rule_name = &rules[current_rule_idx].name;
4630
4631 for entry in ®istry.entries {
4632 let source_state_idx = rules.iter().position(|r| r.name == entry.rule_name);
4634 let Some(source_idx) = source_state_idx else {
4635 continue;
4636 };
4637
4638 let is_self = entry.rule_name == *current_rule_name;
4639 let data = if is_self {
4640 states[source_idx].all_delta().to_vec()
4642 } else {
4643 states[source_idx].all_facts().to_vec()
4645 };
4646
4647 let data = if data.is_empty() || data.iter().all(|b| b.num_rows() == 0) {
4649 vec![RecordBatch::new_empty(Arc::clone(&entry.schema))]
4650 } else {
4651 data
4652 };
4653
4654 let mut guard = entry.data.write();
4655 *guard = data;
4656 }
4657}
4658
4659pub struct DerivedScanExec {
4669 data: Arc<RwLock<Vec<RecordBatch>>>,
4670 schema: SchemaRef,
4671 properties: Arc<PlanProperties>,
4672}
4673
4674impl DerivedScanExec {
4675 pub fn new(data: Arc<RwLock<Vec<RecordBatch>>>, schema: SchemaRef) -> Self {
4676 let properties = compute_plan_properties(Arc::clone(&schema));
4677 Self {
4678 data,
4679 schema,
4680 properties,
4681 }
4682 }
4683}
4684
4685impl fmt::Debug for DerivedScanExec {
4686 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
4687 f.debug_struct("DerivedScanExec")
4688 .field("schema", &self.schema)
4689 .finish()
4690 }
4691}
4692
4693impl DisplayAs for DerivedScanExec {
4694 fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter<'_>) -> fmt::Result {
4695 write!(f, "DerivedScanExec")
4696 }
4697}
4698
4699impl ExecutionPlan for DerivedScanExec {
4700 fn name(&self) -> &str {
4701 "DerivedScanExec"
4702 }
4703 fn as_any(&self) -> &dyn Any {
4704 self
4705 }
4706 fn schema(&self) -> SchemaRef {
4707 Arc::clone(&self.schema)
4708 }
4709 fn properties(&self) -> &Arc<PlanProperties> {
4710 &self.properties
4711 }
4712 fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
4713 vec![]
4714 }
4715 fn with_new_children(
4716 self: Arc<Self>,
4717 _children: Vec<Arc<dyn ExecutionPlan>>,
4718 ) -> DFResult<Arc<dyn ExecutionPlan>> {
4719 Ok(self)
4720 }
4721 fn execute(
4722 &self,
4723 _partition: usize,
4724 _context: Arc<TaskContext>,
4725 ) -> DFResult<SendableRecordBatchStream> {
4726 let batches = {
4727 let guard = self.data.read();
4728 if guard.is_empty() {
4729 vec![RecordBatch::new_empty(Arc::clone(&self.schema))]
4730 } else {
4731 guard.clone()
4732 }
4733 };
4734 Ok(Box::pin(MemoryStream::try_new(
4735 batches,
4736 Arc::clone(&self.schema),
4737 None,
4738 )?))
4739 }
4740}
4741
4742struct InMemoryExec {
4751 batches: Vec<RecordBatch>,
4752 schema: SchemaRef,
4753 properties: Arc<PlanProperties>,
4754}
4755
4756impl InMemoryExec {
4757 fn new(batches: Vec<RecordBatch>, schema: SchemaRef) -> Self {
4758 let properties = compute_plan_properties(Arc::clone(&schema));
4759 Self {
4760 batches,
4761 schema,
4762 properties,
4763 }
4764 }
4765}
4766
4767impl fmt::Debug for InMemoryExec {
4768 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
4769 f.debug_struct("InMemoryExec")
4770 .field("num_batches", &self.batches.len())
4771 .field("schema", &self.schema)
4772 .finish()
4773 }
4774}
4775
4776impl DisplayAs for InMemoryExec {
4777 fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter<'_>) -> fmt::Result {
4778 write!(f, "InMemoryExec: batches={}", self.batches.len())
4779 }
4780}
4781
4782impl ExecutionPlan for InMemoryExec {
4783 fn name(&self) -> &str {
4784 "InMemoryExec"
4785 }
4786 fn as_any(&self) -> &dyn Any {
4787 self
4788 }
4789 fn schema(&self) -> SchemaRef {
4790 Arc::clone(&self.schema)
4791 }
4792 fn properties(&self) -> &Arc<PlanProperties> {
4793 &self.properties
4794 }
4795 fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
4796 vec![]
4797 }
4798 fn with_new_children(
4799 self: Arc<Self>,
4800 _children: Vec<Arc<dyn ExecutionPlan>>,
4801 ) -> DFResult<Arc<dyn ExecutionPlan>> {
4802 Ok(self)
4803 }
4804 fn execute(
4805 &self,
4806 _partition: usize,
4807 _context: Arc<TaskContext>,
4808 ) -> DFResult<SendableRecordBatchStream> {
4809 Ok(Box::pin(MemoryStream::try_new(
4810 self.batches.clone(),
4811 Arc::clone(&self.schema),
4812 None,
4813 )?))
4814 }
4815}
4816
4817fn apply_having_filter(
4827 batches: Vec<RecordBatch>,
4828 having_exprs: &[Expr],
4829 schema: &SchemaRef,
4830 task_ctx: &Arc<TaskContext>,
4831) -> DFResult<Vec<RecordBatch>> {
4832 use arrow::compute::{and, filter_record_batch};
4833 use arrow_array::BooleanArray;
4834 use datafusion::common::DFSchema;
4835 use datafusion::logical_expr::LogicalPlanBuilder;
4836 use datafusion::logical_expr::execution_props::ExecutionProps;
4837 use datafusion::optimizer::AnalyzerRule;
4838 use datafusion::optimizer::analyzer::type_coercion::TypeCoercion;
4839 use datafusion::physical_expr::create_physical_expr;
4840
4841 if batches.is_empty() {
4842 return Ok(batches);
4843 }
4844
4845 let df_schema = DFSchema::try_from(schema.as_ref().clone()).map_err(|e| {
4847 datafusion::common::DataFusionError::Internal(format!("HAVING schema conversion: {e}"))
4848 })?;
4849
4850 let config = (**task_ctx.session_config().options()).clone();
4855 let props = ExecutionProps::new();
4856
4857 let physical_exprs: Vec<Arc<dyn datafusion::physical_expr::PhysicalExpr>> = having_exprs
4863 .iter()
4864 .map(|expr| {
4865 let df_expr = crate::query::df_expr::cypher_expr_to_df(expr, None).map_err(|e| {
4866 datafusion::common::DataFusionError::Internal(format!(
4867 "HAVING expression conversion: {e}"
4868 ))
4869 })?;
4870
4871 let empty = datafusion::logical_expr::LogicalPlan::EmptyRelation(
4875 datafusion::logical_expr::EmptyRelation {
4876 produce_one_row: false,
4877 schema: Arc::new(df_schema.clone()),
4878 },
4879 );
4880 let filter_plan = LogicalPlanBuilder::from(empty)
4881 .filter(df_expr.clone())?
4882 .build()?;
4883 let coerced_expr = match TypeCoercion::new().analyze(filter_plan, &config) {
4884 Ok(datafusion::logical_expr::LogicalPlan::Filter(f)) => f.predicate,
4885 _ => df_expr,
4886 };
4887
4888 create_physical_expr(&coerced_expr, &df_schema, &props)
4889 })
4890 .collect::<DFResult<Vec<_>>>()?;
4891
4892 let mut result = Vec::new();
4893 for batch in batches {
4894 let mut mask: Option<BooleanArray> = None;
4896 for phys_expr in &physical_exprs {
4897 let value = phys_expr.evaluate(&batch)?;
4898 let arr = value.into_array(batch.num_rows())?;
4899 let bool_arr = arr.as_any().downcast_ref::<BooleanArray>().ok_or_else(|| {
4900 datafusion::common::DataFusionError::Internal(
4901 "HAVING condition must evaluate to boolean".into(),
4902 )
4903 })?;
4904 mask = Some(match mask {
4905 None => bool_arr.clone(),
4906 Some(prev) => and(&prev, bool_arr).map_err(arrow_err)?,
4907 });
4908 }
4909 if let Some(ref m) = mask {
4910 let filtered = filter_record_batch(&batch, m).map_err(arrow_err)?;
4911 if filtered.num_rows() > 0 {
4912 result.push(filtered);
4913 }
4914 } else {
4915 result.push(batch);
4916 }
4917 }
4918 Ok(result)
4919}
4920
4921#[allow(
4923 clippy::too_many_arguments,
4924 reason = "context bundle would be over-engineering for one call site"
4925)]
4926pub(crate) async fn apply_post_fixpoint_chain(
4927 facts: Vec<RecordBatch>,
4928 rule: &FixpointRulePlan,
4929 task_ctx: &Arc<TaskContext>,
4930 strict_probability_domain: bool,
4931 probability_epsilon: f64,
4932 semiring_kind: SemiringKind,
4933 provenance_tracker: Option<Arc<ProvenanceStore>>,
4934 top_k_proofs_k: usize,
4935 registry: Option<Arc<DerivedScanRegistry>>,
4936) -> DFResult<Vec<RecordBatch>> {
4937 if !rule.has_fold && !rule.has_best_by && !rule.has_priority && rule.having.is_empty() {
4938 return Ok(facts);
4939 }
4940
4941 let schema = facts
4946 .iter()
4947 .find(|b| b.num_rows() > 0)
4948 .map(|b| b.schema())
4949 .unwrap_or_else(|| Arc::clone(&rule.yield_schema));
4950
4951 let topk_k: Option<usize> = match semiring_kind {
4965 SemiringKind::TopKProofs { k } if k > 0 => Some(k as usize),
4966 _ => None,
4967 };
4968 let body_support_map: Option<Arc<HashMap<Vec<u8>, Vec<ProofTerm>>>> = if topk_k.is_some()
4969 && !rule.has_priority
4970 && let Some(registry) = registry.as_ref()
4971 {
4972 let mut map: HashMap<Vec<u8>, Vec<ProofTerm>> = HashMap::new();
4973 for batch in &facts {
4974 let all_indices: Vec<usize> = (0..batch.num_columns()).collect();
4975 for row_idx in 0..batch.num_rows() {
4976 let support = collect_is_ref_inputs_for_body_row(rule, batch, row_idx, registry);
4977 if support.is_empty() {
4978 continue;
4979 }
4980 let hash = fact_hash_key(batch, &all_indices, row_idx);
4981 map.insert(hash, support);
4982 }
4983 }
4984 if map.is_empty() {
4985 None
4986 } else {
4987 Some(Arc::new(map))
4988 }
4989 } else {
4990 None
4991 };
4992
4993 let input: Arc<dyn ExecutionPlan> = Arc::new(InMemoryExec::new(facts, schema.clone()));
4994
4995 let key_column_indices: Vec<usize> = rule
5000 .key_column_indices
5001 .iter()
5002 .filter_map(|&i| {
5003 let name = rule.yield_schema.field(i).name();
5004 schema.index_of(name).ok()
5005 })
5006 .collect();
5007
5008 let current: Arc<dyn ExecutionPlan> = if rule.has_priority {
5012 let priority_schema = input.schema();
5013 let priority_idx = priority_schema.index_of("__priority").map_err(|_| {
5014 datafusion::common::DataFusionError::Internal(
5015 "PRIORITY rule missing __priority column".to_string(),
5016 )
5017 })?;
5018 Arc::new(PriorityExec::new(
5019 input,
5020 key_column_indices.clone(),
5021 priority_idx,
5022 ))
5023 } else {
5024 input
5025 };
5026
5027 let current: Arc<dyn ExecutionPlan> = if rule.has_fold && !rule.fold_bindings.is_empty() {
5029 Arc::new(FoldExec::new_with_topk(
5030 current,
5031 key_column_indices.clone(),
5032 rule.fold_bindings.clone(),
5033 strict_probability_domain,
5034 probability_epsilon,
5035 semiring_kind,
5036 provenance_tracker.clone(),
5037 topk_k.unwrap_or(top_k_proofs_k),
5038 body_support_map.clone(),
5039 ))
5040 } else {
5041 current
5042 };
5043
5044 let current: Arc<dyn ExecutionPlan> = if !rule.having.is_empty() {
5046 let batches = collect_all_partitions(¤t, Arc::clone(task_ctx)).await?;
5047 let filtered = apply_having_filter(batches, &rule.having, ¤t.schema(), task_ctx)?;
5048 if filtered.is_empty() {
5049 return Ok(filtered);
5050 }
5051 Arc::new(InMemoryExec::new(filtered, Arc::clone(¤t.schema())))
5052 } else {
5053 current
5054 };
5055
5056 let current: Arc<dyn ExecutionPlan> = if rule.has_best_by && !rule.best_by_criteria.is_empty() {
5058 Arc::new(BestByExec::new(
5059 current,
5060 key_column_indices.clone(),
5061 rule.best_by_criteria.clone(),
5062 rule.deterministic,
5063 ))
5064 } else {
5065 current
5066 };
5067
5068 collect_all_partitions(¤t, Arc::clone(task_ctx)).await
5069}
5070
/// Execution plan for the fixpoint evaluation of a set of Locy rules.
///
/// NOTE(review): the evaluation loop lives outside this view. Several field notes below
/// are inferred from names and constructor notes; confirm them against the
/// `ExecutionPlan` impl.
pub struct FixpointExec {
    /// Plans for the rules evaluated together by this operator.
    rules: Vec<FixpointRulePlan>,
    /// Presumably the maximum number of fixpoint iterations (TODO confirm).
    max_iterations: usize,
    /// Presumably the wall-clock budget for the evaluation (TODO confirm).
    timeout: Duration,
    graph_ctx: Arc<GraphExecutionContext>,
    session_ctx: Arc<RwLock<datafusion::prelude::SessionContext>>,
    storage: Arc<StorageManager>,
    schema_info: Arc<UniSchema>,
    /// Parameter values keyed by name.
    params: HashMap<String, Value>,
    /// Derived-scan handles (see `DerivedScanEntry`). Presumably refreshed between
    /// iterations via `update_derived_scan_handles`.
    derived_scan_registry: Arc<DerivedScanRegistry>,
    output_schema: SchemaRef,
    properties: Arc<PlanProperties>,
    metrics: ExecutionPlanMetricsSet,
    /// Presumably a byte budget for derived facts (TODO confirm).
    max_derived_bytes: usize,
    derivation_tracker: Option<Arc<ProvenanceStore>>,
    /// Shared output slot; presumably maps rule name to iterations performed.
    iteration_counts: Arc<StdRwLock<HashMap<String, usize>>>,
    strict_probability_domain: bool,
    probability_epsilon: f64,
    exact_probability: bool,
    max_bdd_variables: usize,
    /// Shared sink for `RuntimeWarning`s.
    warnings_slot: Arc<StdRwLock<Vec<RuntimeWarning>>>,
    approximate_slot: Arc<StdRwLock<HashMap<String, Vec<String>>>>,
    top_k_proofs: usize,
    /// Shared atomic flag; presumably set on timeout (TODO confirm encoding).
    timeout_flag: Arc<std::sync::atomic::AtomicU8>,
    /// Semiring kind; the deprecated `new` constructor uses `SemiringKind::AddMultProb`.
    semiring_kind: SemiringKind,
    /// Classifier registry; the deprecated `new` constructor supplies an empty one.
    classifier_registry: Arc<ClassifierRegistry>,
    classifier_cache: Option<Arc<ModelInvocationCache>>,
    #[allow(
        dead_code,
        reason = "boundary plumbing; read by EXPLAIN via LocyModelInvokeExec"
    )]
    classifier_provenance_store: Option<Arc<uni_locy::NeuralProvenanceStore>>,
}
5133
5134impl fmt::Debug for FixpointExec {
5135 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
5136 f.debug_struct("FixpointExec")
5137 .field("rules_count", &self.rules.len())
5138 .field("max_iterations", &self.max_iterations)
5139 .field("timeout", &self.timeout)
5140 .field("output_schema", &self.output_schema)
5141 .field("max_derived_bytes", &self.max_derived_bytes)
5142 .finish_non_exhaustive()
5143 }
5144}
5145
impl FixpointExec {
    /// Legacy constructor.
    ///
    /// Forwards every argument to `new_with_semiring_and_classifiers`. It fixes
    /// the semiring to `SemiringKind::AddMultProb` and supplies a freshly
    /// created, empty `ClassifierRegistry`. Prefer one of the explicit
    /// constructors below.
    #[expect(
        clippy::too_many_arguments,
        reason = "FixpointExec configuration needs all context"
    )]
    #[deprecated(
        note = "use `new_with_semiring_classifiers_and_cache` (or the lighter \
                `new_with_semiring_and_classifiers` / `new_with_semiring`) — \
                this legacy ctor defaults the semiring to AddMultProb and \
                ships no classifier registry, which the Phase B+ runtime needs \
                explicitly. To be removed after C0 Stage 2."
    )]
    pub fn new(
        rules: Vec<FixpointRulePlan>,
        max_iterations: usize,
        timeout: Duration,
        graph_ctx: Arc<GraphExecutionContext>,
        session_ctx: Arc<RwLock<datafusion::prelude::SessionContext>>,
        storage: Arc<StorageManager>,
        schema_info: Arc<UniSchema>,
        params: HashMap<String, Value>,
        derived_scan_registry: Arc<DerivedScanRegistry>,
        output_schema: SchemaRef,
        max_derived_bytes: usize,
        derivation_tracker: Option<Arc<ProvenanceStore>>,
        iteration_counts: Arc<StdRwLock<HashMap<String, usize>>>,
        strict_probability_domain: bool,
        probability_epsilon: f64,
        exact_probability: bool,
        max_bdd_variables: usize,
        warnings_slot: Arc<StdRwLock<Vec<RuntimeWarning>>>,
        approximate_slot: Arc<StdRwLock<HashMap<String, Vec<String>>>>,
        top_k_proofs: usize,
        timeout_flag: Arc<std::sync::atomic::AtomicU8>,
    ) -> Self {
        Self::new_with_semiring_and_classifiers(
            rules,
            max_iterations,
            timeout,
            graph_ctx,
            session_ctx,
            storage,
            schema_info,
            params,
            derived_scan_registry,
            output_schema,
            max_derived_bytes,
            derivation_tracker,
            iteration_counts,
            strict_probability_domain,
            probability_epsilon,
            exact_probability,
            max_bdd_variables,
            warnings_slot,
            approximate_slot,
            top_k_proofs,
            timeout_flag,
            // Legacy defaults: probabilistic add/mult semiring, no classifiers.
            SemiringKind::AddMultProb,
            Arc::new(ClassifierRegistry::new()),
        )
    }

    /// Constructor with an explicit semiring.
    ///
    /// Identical to `new` except that `semiring_kind` is chosen by the caller.
    /// The classifier registry is still a fresh, empty `ClassifierRegistry`.
    #[expect(
        clippy::too_many_arguments,
        reason = "FixpointExec configuration needs all context"
    )]
    pub fn new_with_semiring(
        rules: Vec<FixpointRulePlan>,
        max_iterations: usize,
        timeout: Duration,
        graph_ctx: Arc<GraphExecutionContext>,
        session_ctx: Arc<RwLock<datafusion::prelude::SessionContext>>,
        storage: Arc<StorageManager>,
        schema_info: Arc<UniSchema>,
        params: HashMap<String, Value>,
        derived_scan_registry: Arc<DerivedScanRegistry>,
        output_schema: SchemaRef,
        max_derived_bytes: usize,
        derivation_tracker: Option<Arc<ProvenanceStore>>,
        iteration_counts: Arc<StdRwLock<HashMap<String, usize>>>,
        strict_probability_domain: bool,
        probability_epsilon: f64,
        exact_probability: bool,
        max_bdd_variables: usize,
        warnings_slot: Arc<StdRwLock<Vec<RuntimeWarning>>>,
        approximate_slot: Arc<StdRwLock<HashMap<String, Vec<String>>>>,
        top_k_proofs: usize,
        timeout_flag: Arc<std::sync::atomic::AtomicU8>,
        semiring_kind: SemiringKind,
    ) -> Self {
        Self::new_with_semiring_and_classifiers(
            rules,
            max_iterations,
            timeout,
            graph_ctx,
            session_ctx,
            storage,
            schema_info,
            params,
            derived_scan_registry,
            output_schema,
            max_derived_bytes,
            derivation_tracker,
            iteration_counts,
            strict_probability_domain,
            probability_epsilon,
            exact_probability,
            max_bdd_variables,
            warnings_slot,
            approximate_slot,
            top_k_proofs,
            timeout_flag,
            semiring_kind,
            Arc::new(ClassifierRegistry::new()),
        )
    }

    /// Constructor with an explicit semiring and classifier registry.
    ///
    /// Forwards to `new_with_semiring_classifiers_and_cache` with no
    /// model-invocation cache and no neural provenance store (`None`, `None`).
    #[expect(
        clippy::too_many_arguments,
        reason = "FixpointExec configuration needs all context"
    )]
    pub fn new_with_semiring_and_classifiers(
        rules: Vec<FixpointRulePlan>,
        max_iterations: usize,
        timeout: Duration,
        graph_ctx: Arc<GraphExecutionContext>,
        session_ctx: Arc<RwLock<datafusion::prelude::SessionContext>>,
        storage: Arc<StorageManager>,
        schema_info: Arc<UniSchema>,
        params: HashMap<String, Value>,
        derived_scan_registry: Arc<DerivedScanRegistry>,
        output_schema: SchemaRef,
        max_derived_bytes: usize,
        derivation_tracker: Option<Arc<ProvenanceStore>>,
        iteration_counts: Arc<StdRwLock<HashMap<String, usize>>>,
        strict_probability_domain: bool,
        probability_epsilon: f64,
        exact_probability: bool,
        max_bdd_variables: usize,
        warnings_slot: Arc<StdRwLock<Vec<RuntimeWarning>>>,
        approximate_slot: Arc<StdRwLock<HashMap<String, Vec<String>>>>,
        top_k_proofs: usize,
        timeout_flag: Arc<std::sync::atomic::AtomicU8>,
        semiring_kind: SemiringKind,
        classifier_registry: Arc<ClassifierRegistry>,
    ) -> Self {
        Self::new_with_semiring_classifiers_and_cache(
            rules,
            max_iterations,
            timeout,
            graph_ctx,
            session_ctx,
            storage,
            schema_info,
            params,
            derived_scan_registry,
            output_schema,
            max_derived_bytes,
            derivation_tracker,
            iteration_counts,
            strict_probability_domain,
            probability_epsilon,
            exact_probability,
            max_bdd_variables,
            warnings_slot,
            approximate_slot,
            top_k_proofs,
            timeout_flag,
            semiring_kind,
            classifier_registry,
            // classifier_cache
            None,
            // classifier_provenance_store
            None,
        )
    }

    /// Full constructor. The other constructors in this impl all delegate here.
    ///
    /// It stores every argument unchanged and derives the plan properties from
    /// `output_schema` via `compute_plan_properties`. It starts with an empty
    /// `ExecutionPlanMetricsSet`. The configuration is handed to
    /// `run_fixpoint_loop` each time `execute` is called.
    ///
    /// Notes on selected arguments:
    /// - `output_schema`: the schema reported by `schema()` and by the stream
    ///   returned from `execute`.
    /// - `iteration_counts`: a shared map, also exposed through
    ///   `iteration_counts()`.
    /// - `classifier_provenance_store`: not read by this operator itself. The
    ///   field's lint rationale says EXPLAIN reads it via `LocyModelInvokeExec`.
    ///
    /// The exact semantics of the remaining knobs (epsilon, BDD limits, top-k
    /// proofs, timeout flag, ...) are defined by `run_fixpoint_loop`, which is
    /// not shown here.
    #[expect(
        clippy::too_many_arguments,
        reason = "FixpointExec configuration needs all context"
    )]
    pub fn new_with_semiring_classifiers_and_cache(
        rules: Vec<FixpointRulePlan>,
        max_iterations: usize,
        timeout: Duration,
        graph_ctx: Arc<GraphExecutionContext>,
        session_ctx: Arc<RwLock<datafusion::prelude::SessionContext>>,
        storage: Arc<StorageManager>,
        schema_info: Arc<UniSchema>,
        params: HashMap<String, Value>,
        derived_scan_registry: Arc<DerivedScanRegistry>,
        output_schema: SchemaRef,
        max_derived_bytes: usize,
        derivation_tracker: Option<Arc<ProvenanceStore>>,
        iteration_counts: Arc<StdRwLock<HashMap<String, usize>>>,
        strict_probability_domain: bool,
        probability_epsilon: f64,
        exact_probability: bool,
        max_bdd_variables: usize,
        warnings_slot: Arc<StdRwLock<Vec<RuntimeWarning>>>,
        approximate_slot: Arc<StdRwLock<HashMap<String, Vec<String>>>>,
        top_k_proofs: usize,
        timeout_flag: Arc<std::sync::atomic::AtomicU8>,
        semiring_kind: SemiringKind,
        classifier_registry: Arc<ClassifierRegistry>,
        classifier_cache: Option<Arc<ModelInvocationCache>>,
        classifier_provenance_store: Option<Arc<uni_locy::NeuralProvenanceStore>>,
    ) -> Self {
        // Computed once here; `properties()` hands out a reference to it.
        let properties = compute_plan_properties(Arc::clone(&output_schema));
        Self {
            rules,
            max_iterations,
            timeout,
            graph_ctx,
            session_ctx,
            storage,
            schema_info,
            params,
            derived_scan_registry,
            output_schema,
            properties,
            metrics: ExecutionPlanMetricsSet::new(),
            max_derived_bytes,
            derivation_tracker,
            iteration_counts,
            strict_probability_domain,
            probability_epsilon,
            exact_probability,
            max_bdd_variables,
            warnings_slot,
            approximate_slot,
            top_k_proofs,
            timeout_flag,
            semiring_kind,
            classifier_registry,
            classifier_cache,
            classifier_provenance_store,
        }
    }

    /// Returns a shared handle to the iteration-count map.
    ///
    /// `execute` hands the same `Arc` to the fixpoint loop, so updates made
    /// there are visible through this handle. NOTE(review): the keys are
    /// presumably rule names; confirm against `run_fixpoint_loop`.
    pub fn iteration_counts(&self) -> Arc<StdRwLock<HashMap<String, usize>>> {
        Arc::clone(&self.iteration_counts)
    }
}
5399
5400impl DisplayAs for FixpointExec {
5401 fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter<'_>) -> fmt::Result {
5402 write!(
5403 f,
5404 "FixpointExec: rules=[{}], max_iter={}, timeout={:?}",
5405 self.rules
5406 .iter()
5407 .map(|r| r.name.as_str())
5408 .collect::<Vec<_>>()
5409 .join(", "),
5410 self.max_iterations,
5411 self.timeout,
5412 )
5413 }
5414}
5415
impl ExecutionPlan for FixpointExec {
    fn name(&self) -> &str {
        "FixpointExec"
    }

    fn as_any(&self) -> &dyn Any {
        self
    }

    /// Schema of the stream produced by `execute`: the `output_schema` given at
    /// construction.
    fn schema(&self) -> SchemaRef {
        Arc::clone(&self.output_schema)
    }

    /// Plan properties, computed once in the constructor from `output_schema`.
    fn properties(&self) -> &Arc<PlanProperties> {
        &self.properties
    }

    /// Leaf operator. Rule bodies travel as logical plans inside `self.rules`
    /// (`body_logical`), not as physical child plans.
    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
        vec![]
    }

    /// Because there are no children, only an empty replacement list is
    /// accepted, and `self` is returned unchanged. A non-empty list is reported
    /// as a planning error.
    fn with_new_children(
        self: Arc<Self>,
        children: Vec<Arc<dyn ExecutionPlan>>,
    ) -> DFResult<Arc<dyn ExecutionPlan>> {
        if !children.is_empty() {
            return Err(datafusion::error::DataFusionError::Plan(
                "FixpointExec has no children".to_string(),
            ));
        }
        Ok(self)
    }

    /// Starts the fixpoint evaluation and returns its output stream.
    ///
    /// The call returns immediately with a `FixpointStream` that wraps a
    /// not-yet-polled future over `run_fixpoint_loop`. The loop only makes
    /// progress when the stream is polled; the stream then emits the returned
    /// batches one by one.
    ///
    /// `partition` is used only to register per-partition metrics. `_context`
    /// (the DataFusion `TaskContext`) is ignored; the stored `session_ctx` is
    /// passed to the loop instead.
    fn execute(
        &self,
        partition: usize,
        _context: Arc<TaskContext>,
    ) -> DFResult<SendableRecordBatchStream> {
        let metrics = BaselineMetrics::new(&self.metrics, partition);

        // Rebuild an owned copy of every rule (and clause) plan so the future
        // below owns its inputs instead of borrowing `self`. These are full
        // struct literals with no `..` rest pattern, so adding a field to
        // `FixpointRulePlan` / `FixpointClausePlan` fails to compile until it
        // is listed here.
        // NOTE(review): if both types implement `Clone`, this is equivalent to
        // `self.rules.clone()`.
        let rules = self
            .rules
            .iter()
            .map(|r| {
                FixpointRulePlan {
                    name: r.name.clone(),
                    clauses: r
                        .clauses
                        .iter()
                        .map(|c| FixpointClausePlan {
                            body_logical: c.body_logical.clone(),
                            is_ref_bindings: c.is_ref_bindings.clone(),
                            priority: c.priority,
                            along_bindings: c.along_bindings.clone(),
                            model_invocations: c.model_invocations.clone(),
                        })
                        .collect(),
                    yield_schema: Arc::clone(&r.yield_schema),
                    key_column_indices: r.key_column_indices.clone(),
                    priority: r.priority,
                    has_fold: r.has_fold,
                    fold_bindings: r.fold_bindings.clone(),
                    having: r.having.clone(),
                    has_best_by: r.has_best_by,
                    best_by_criteria: r.best_by_criteria.clone(),
                    has_priority: r.has_priority,
                    deterministic: r.deterministic,
                    prob_column_name: r.prob_column_name.clone(),
                }
            })
            .collect();

        // Snapshot the configuration into owned locals so the `async move`
        // block can take ownership. `Arc::clone` only bumps reference counts,
        // so the shared slots (iteration counts, warnings, approximate slot,
        // timeout flag) stay shared with `self` and with any other callers
        // holding the same `Arc`s.
        let max_iterations = self.max_iterations;
        let timeout = self.timeout;
        let graph_ctx = Arc::clone(&self.graph_ctx);
        let session_ctx = Arc::clone(&self.session_ctx);
        let storage = Arc::clone(&self.storage);
        let schema_info = Arc::clone(&self.schema_info);
        let params = self.params.clone();
        let registry = Arc::clone(&self.derived_scan_registry);
        let output_schema = Arc::clone(&self.output_schema);
        let max_derived_bytes = self.max_derived_bytes;
        let derivation_tracker = self.derivation_tracker.clone();
        let iteration_counts = Arc::clone(&self.iteration_counts);
        let strict_probability_domain = self.strict_probability_domain;
        let probability_epsilon = self.probability_epsilon;
        let exact_probability = self.exact_probability;
        let max_bdd_variables = self.max_bdd_variables;
        let warnings_slot = Arc::clone(&self.warnings_slot);
        let approximate_slot = Arc::clone(&self.approximate_slot);
        let top_k_proofs = self.top_k_proofs;
        let timeout_flag = Arc::clone(&self.timeout_flag);
        let semiring_kind = self.semiring_kind;
        let classifier_registry = Arc::clone(&self.classifier_registry);
        let classifier_cache = self.classifier_cache.as_ref().map(Arc::clone);
        let classifier_provenance_store = self.classifier_provenance_store.as_ref().map(Arc::clone);

        // Futures are lazy: nothing below runs until `FixpointStream` polls it.
        let fut = async move {
            run_fixpoint_loop(
                rules,
                max_iterations,
                timeout,
                graph_ctx,
                session_ctx,
                storage,
                schema_info,
                params,
                registry,
                output_schema,
                max_derived_bytes,
                derivation_tracker,
                iteration_counts,
                strict_probability_domain,
                probability_epsilon,
                exact_probability,
                max_bdd_variables,
                warnings_slot,
                approximate_slot,
                top_k_proofs,
                timeout_flag,
                semiring_kind,
                classifier_registry,
                classifier_cache,
                classifier_provenance_store,
            )
            .await
        };

        Ok(Box::pin(FixpointStream {
            state: FixpointStreamState::Running(Box::pin(fut)),
            schema: Arc::clone(&self.output_schema),
            metrics,
        }))
    }

    /// Snapshot of the metrics recorded by streams created from this plan.
    fn metrics(&self) -> Option<MetricsSet> {
        Some(self.metrics.clone_inner())
    }
}
5560
/// State machine behind `FixpointStream`.
enum FixpointStreamState {
    /// The fixpoint computation. It is polled until it resolves to the complete
    /// result (or an error).
    Running(Pin<Box<dyn std::future::Future<Output = DFResult<Vec<RecordBatch>>> + Send>>),
    /// The materialised result. The `usize` is the index of the next batch to
    /// emit.
    Emitting(Vec<RecordBatch>, usize),
    /// Terminal state: the stream is exhausted or has already reported an error.
    Done,
}
5573
/// Output stream of `FixpointExec`'s `execute`. It first drives the fixpoint
/// future to completion, then replays the resulting batches one at a time.
struct FixpointStream {
    /// Current phase: running, emitting or done.
    state: FixpointStreamState,
    /// Schema reported through `RecordBatchStream::schema`.
    schema: SchemaRef,
    /// Metrics for the partition this stream was created for (output rows,
    /// compute time).
    metrics: BaselineMetrics,
}
5579
5580impl Stream for FixpointStream {
5581 type Item = DFResult<RecordBatch>;
5582
5583 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
5584 let this = self.get_mut();
5585 let metrics = this.metrics.clone();
5586 let _timer = metrics.elapsed_compute().timer();
5587 loop {
5588 match &mut this.state {
5589 FixpointStreamState::Running(fut) => match fut.as_mut().poll(cx) {
5590 Poll::Ready(Ok(batches)) => {
5591 if batches.is_empty() {
5592 this.state = FixpointStreamState::Done;
5593 return Poll::Ready(None);
5594 }
5595 this.state = FixpointStreamState::Emitting(batches, 0);
5596 }
5598 Poll::Ready(Err(e)) => {
5599 this.state = FixpointStreamState::Done;
5600 return Poll::Ready(Some(Err(e)));
5601 }
5602 Poll::Pending => return Poll::Pending,
5603 },
5604 FixpointStreamState::Emitting(batches, idx) => {
5605 if *idx >= batches.len() {
5606 this.state = FixpointStreamState::Done;
5607 return Poll::Ready(None);
5608 }
5609 let batch = batches[*idx].clone();
5610 *idx += 1;
5611 this.metrics.record_output(batch.num_rows());
5612 return Poll::Ready(Some(Ok(batch)));
5613 }
5614 FixpointStreamState::Done => return Poll::Ready(None),
5615 }
5616 }
5617 }
5618}
5619
5620impl RecordBatchStream for FixpointStream {
5621 fn schema(&self) -> SchemaRef {
5622 Arc::clone(&self.schema)
5623 }
5624}
5625
5626#[cfg(test)]
5631mod tests {
5632 use super::*;
5633 use arrow_array::{Float64Array, Int64Array, StringArray};
5634 use arrow_schema::{DataType, Field, Schema};
5635
5636 fn test_schema() -> SchemaRef {
5637 Arc::new(Schema::new(vec![
5638 Field::new("name", DataType::Utf8, true),
5639 Field::new("value", DataType::Int64, true),
5640 ]))
5641 }
5642
5643 fn make_batch(names: &[&str], values: &[i64]) -> RecordBatch {
5644 RecordBatch::try_new(
5645 test_schema(),
5646 vec![
5647 Arc::new(StringArray::from(
5648 names.iter().map(|s| Some(*s)).collect::<Vec<_>>(),
5649 )),
5650 Arc::new(Int64Array::from(values.to_vec())),
5651 ],
5652 )
5653 .unwrap()
5654 }
5655
    /// First merge into an empty state: every incoming row is new, so all rows
    /// land in both the accumulated facts and the delta.
    #[tokio::test]
    async fn test_fixpoint_state_empty_facts_adds_all() {
        let schema = test_schema();
        // NOTE(review): argument meaning is defined by `FixpointState::new`
        // (not visible here). Presumably: name, schema, key column indices,
        // byte budget (see `test_memory_limit_exceeded`), then two options
        // left at `None` / `false`.
        let mut state = FixpointState::new("test".into(), schema, vec![0], 1_000_000, None, false);

        let batch = make_batch(&["a", "b", "c"], &[1, 2, 3]);
        let changed = state.merge_delta(vec![batch], None).await.unwrap();

        assert!(changed);
        assert_eq!(state.all_facts().len(), 1);
        assert_eq!(state.all_facts()[0].num_rows(), 3);
        assert_eq!(state.all_delta().len(), 1);
        assert_eq!(state.all_delta()[0].num_rows(), 3);
    }

    /// Re-merging rows that are already known reports "no change" and leaves
    /// an empty delta.
    #[tokio::test]
    async fn test_fixpoint_state_exact_duplicates_excluded() {
        let schema = test_schema();
        let mut state = FixpointState::new("test".into(), schema, vec![0], 1_000_000, None, false);

        let batch1 = make_batch(&["a", "b"], &[1, 2]);
        state.merge_delta(vec![batch1], None).await.unwrap();

        // Exactly the same rows again.
        let batch2 = make_batch(&["a", "b"], &[1, 2]);
        let changed = state.merge_delta(vec![batch2], None).await.unwrap();
        assert!(!changed);
        assert!(
            state.all_delta().is_empty() || state.all_delta().iter().all(|b| b.num_rows() == 0)
        );
    }

    /// When a batch partially overlaps known facts, only the unseen row enters
    /// the delta, while the facts grow by that row.
    #[tokio::test]
    async fn test_fixpoint_state_partial_overlap() {
        let schema = test_schema();
        let mut state = FixpointState::new("test".into(), schema, vec![0], 1_000_000, None, false);

        let batch1 = make_batch(&["a", "b"], &[1, 2]);
        state.merge_delta(vec![batch1], None).await.unwrap();

        // ("a", 1) is already known; ("c", 3) is new.
        let batch2 = make_batch(&["a", "c"], &[1, 3]);
        let changed = state.merge_delta(vec![batch2], None).await.unwrap();
        assert!(changed);

        // Only ("c", 3) is new.
        let delta_rows: usize = state.all_delta().iter().map(|b| b.num_rows()).sum();
        assert_eq!(delta_rows, 1);

        // a, b and c in total.
        let total_rows: usize = state.all_facts().iter().map(|b| b.num_rows()).sum();
        assert_eq!(total_rows, 3);
    }

    /// Merging an empty candidate set after some facts exist produces no change,
    /// and the state reports convergence.
    #[tokio::test]
    async fn test_fixpoint_state_convergence() {
        let schema = test_schema();
        let mut state = FixpointState::new("test".into(), schema, vec![0], 1_000_000, None, false);

        let batch = make_batch(&["a"], &[1]);
        state.merge_delta(vec![batch], None).await.unwrap();

        // Nothing new derived in this round.
        let changed = state.merge_delta(vec![], None).await.unwrap();
        assert!(!changed);
        assert!(state.is_converged());
    }
5725
    /// `RowDedupState` remembers rows across `compute_delta` calls. A repeated
    /// batch yields nothing; a partially new batch yields only the new row.
    #[test]
    fn test_row_dedup_persistent_across_calls() {
        let schema = test_schema();
        let mut rd = RowDedupState::try_new(&schema).expect("schema should be supported");

        let batch1 = make_batch(&["a", "b"], &[1, 2]);
        let delta1 = rd.compute_delta(&[batch1], &schema).unwrap();
        let rows1: usize = delta1.iter().map(|b| b.num_rows()).sum();
        // Both rows are unseen.
        assert_eq!(rows1, 2);

        let batch2 = make_batch(&["a", "b"], &[1, 2]);
        let delta2 = rd.compute_delta(&[batch2], &schema).unwrap();
        let rows2: usize = delta2.iter().map(|b| b.num_rows()).sum();
        // The same rows again are filtered out by the state from the first call.
        assert_eq!(rows2, 0);

        let batch3 = make_batch(&["a", "c"], &[1, 3]);
        let delta3 = rd.compute_delta(&[batch3], &schema).unwrap();
        let rows3: usize = delta3.iter().map(|b| b.num_rows()).sum();
        // Only ("c", 3) is new.
        assert_eq!(rows3, 1);
    }
5753
5754 #[test]
5755 fn test_row_dedup_null_handling() {
5756 use arrow_array::StringArray;
5757 use arrow_schema::{DataType, Field, Schema};
5758
5759 let schema: SchemaRef = Arc::new(Schema::new(vec![
5760 Field::new("a", DataType::Utf8, true),
5761 Field::new("b", DataType::Int64, true),
5762 ]));
5763 let mut rd = RowDedupState::try_new(&schema).expect("schema should be supported");
5764
5765 let batch_nulls = RecordBatch::try_new(
5767 Arc::clone(&schema),
5768 vec![
5769 Arc::new(StringArray::from(vec![None::<&str>, None::<&str>])),
5770 Arc::new(arrow_array::Int64Array::from(vec![1i64, 1i64])),
5771 ],
5772 )
5773 .unwrap();
5774 let delta = rd.compute_delta(&[batch_nulls], &schema).unwrap();
5775 let rows: usize = delta.iter().map(|b| b.num_rows()).sum();
5776 assert_eq!(rows, 1, "two identical NULL rows should be deduped to one");
5777
5778 let batch_diff = RecordBatch::try_new(
5780 Arc::clone(&schema),
5781 vec![
5782 Arc::new(StringArray::from(vec![None::<&str>])),
5783 Arc::new(arrow_array::Int64Array::from(vec![2i64])),
5784 ],
5785 )
5786 .unwrap();
5787 let delta2 = rd.compute_delta(&[batch_diff], &schema).unwrap();
5788 let rows2: usize = delta2.iter().map(|b| b.num_rows()).sum();
5789 assert_eq!(rows2, 1, "(NULL, 2) is distinct from (NULL, 1)");
5790 }
5791
    /// Duplicates inside a single candidate batch are collapsed as well, not
    /// just duplicates across calls.
    #[test]
    fn test_row_dedup_within_candidate_dedup() {
        let schema = test_schema();
        let mut rd = RowDedupState::try_new(&schema).expect("schema should be supported");

        // ("a", 1) appears twice in the same batch.
        let batch = make_batch(&["a", "a", "b"], &[1, 1, 2]);
        let delta = rd.compute_delta(&[batch], &schema).unwrap();
        let rows: usize = delta.iter().map(|b| b.num_rows()).sum();
        assert_eq!(rows, 2, "within-batch dup should be collapsed: a:1, b:2");
    }

    /// `round_float_columns` maps floats that differ only in the last few
    /// digits (here around 1e-13) to the same value, so near-duplicates can be
    /// deduplicated.
    #[test]
    fn test_round_float_columns_near_duplicates() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("name", DataType::Utf8, true),
            Field::new("dist", DataType::Float64, true),
        ]));
        let batch = RecordBatch::try_new(
            Arc::clone(&schema),
            vec![
                Arc::new(StringArray::from(vec![Some("a"), Some("a")])),
                Arc::new(Float64Array::from(vec![1.0000000000001, 1.0000000000002])),
            ],
        )
        .unwrap();

        let rounded = round_float_columns(&[batch]);
        assert_eq!(rounded.len(), 1);
        let col = rounded[0]
            .column(1)
            .as_any()
            .downcast_ref::<Float64Array>()
            .unwrap();
        assert_eq!(col.value(0), col.value(1));
    }
5832
    /// `write_data` stores batches in the entry's shared `data` slot, and they
    /// can be read back through `get`.
    #[test]
    fn test_registry_write_read_round_trip() {
        let schema = test_schema();
        let data = Arc::new(RwLock::new(Vec::new()));
        let mut reg = DerivedScanRegistry::new();
        reg.add(DerivedScanEntry {
            scan_index: 0,
            rule_name: "reachable".into(),
            is_self_ref: true,
            data: Arc::clone(&data),
            schema: Arc::clone(&schema),
        });

        let batch = make_batch(&["x"], &[42]);
        reg.write_data(0, vec![batch.clone()]);

        let entry = reg.get(0).unwrap();
        // `RwLock` here is `parking_lot`'s, so `read()` returns the guard directly.
        let guard = entry.data.read();
        assert_eq!(guard.len(), 1);
        assert_eq!(guard[0].num_rows(), 1);
    }

    /// `entries_for_rule` returns every entry registered under a rule name,
    /// whether or not it is a self-reference, and nothing for unknown rules.
    #[test]
    fn test_registry_entries_for_rule() {
        let schema = test_schema();
        let mut reg = DerivedScanRegistry::new();
        reg.add(DerivedScanEntry {
            scan_index: 0,
            rule_name: "r1".into(),
            is_self_ref: true,
            data: Arc::new(RwLock::new(Vec::new())),
            schema: Arc::clone(&schema),
        });
        reg.add(DerivedScanEntry {
            scan_index: 1,
            rule_name: "r2".into(),
            is_self_ref: false,
            data: Arc::new(RwLock::new(Vec::new())),
            schema: Arc::clone(&schema),
        });
        reg.add(DerivedScanEntry {
            scan_index: 2,
            rule_name: "r1".into(),
            is_self_ref: false,
            data: Arc::new(RwLock::new(Vec::new())),
            schema: Arc::clone(&schema),
        });

        assert_eq!(reg.entries_for_rule("r1").len(), 2);
        assert_eq!(reg.entries_for_rule("r2").len(), 1);
        assert_eq!(reg.entries_for_rule("r3").len(), 0);
    }
5887
5888 #[test]
5891 fn test_monotonic_agg_update_and_stability() {
5892 let bindings = vec![MonotonicFoldBinding {
5893 fold_name: "total".into(),
5894 aggregate: std::sync::Arc::new(uni_plugin_builtin::locy_aggregates::SumAgg),
5895 input_col_index: 1,
5896 input_col_name: None,
5897 }];
5898 let mut agg = MonotonicAggState::new(bindings);
5899
5900 let batch = make_batch(&["a"], &[10]);
5902 agg.snapshot();
5903 let changed = agg
5904 .update(&[0], &[batch], false, SemiringKind::AddMultProb)
5905 .unwrap();
5906 assert!(changed);
5907 assert!(!agg.is_stable()); agg.snapshot();
5911 let changed = agg
5912 .update(&[0], &[], false, SemiringKind::AddMultProb)
5913 .unwrap();
5914 assert!(!changed);
5915 assert!(agg.is_stable());
5916 }
5917
    /// With a byte budget of 1, the first merge exceeds it and fails with an
    /// error mentioning "memory limit".
    #[tokio::test]
    async fn test_memory_limit_exceeded() {
        let schema = test_schema();
        // 4th argument = 1: a budget far smaller than any real batch.
        let mut state = FixpointState::new("test".into(), schema, vec![0], 1, None, false);

        let batch = make_batch(&["a", "b", "c"], &[1, 2, 3]);
        let result = state.merge_delta(vec![batch], None).await;
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(err.contains("memory limit"), "Error was: {}", err);
    }

    /// A `FixpointStream` that starts in the `Emitting` state yields its
    /// batches in order, then ends.
    #[tokio::test]
    async fn test_fixpoint_stream_emitting() {
        use futures::StreamExt;

        let schema = test_schema();
        let batch1 = make_batch(&["a"], &[1]);
        let batch2 = make_batch(&["b"], &[2]);

        let metrics = ExecutionPlanMetricsSet::new();
        let baseline = BaselineMetrics::new(&metrics, 0);

        // Bypass the fixpoint future entirely by seeding the emitting state.
        let mut stream = FixpointStream {
            state: FixpointStreamState::Emitting(vec![batch1, batch2], 0),
            schema,
            metrics: baseline,
        };

        let stream = Pin::new(&mut stream);
        let batches: Vec<RecordBatch> = stream.filter_map(|r| async { r.ok() }).collect().await;

        assert_eq!(batches.len(), 2);
        assert_eq!(batches[0].num_rows(), 1);
        assert_eq!(batches[1].num_rows(), 1);
    }
5959
5960 fn make_f64_batch(names: &[&str], values: &[f64]) -> RecordBatch {
5963 let schema = Arc::new(Schema::new(vec![
5964 Field::new("name", DataType::Utf8, true),
5965 Field::new("value", DataType::Float64, true),
5966 ]));
5967 RecordBatch::try_new(
5968 schema,
5969 vec![
5970 Arc::new(StringArray::from(
5971 names.iter().map(|s| Some(*s)).collect::<Vec<_>>(),
5972 )),
5973 Arc::new(Float64Array::from(values.to_vec())),
5974 ],
5975 )
5976 .unwrap()
5977 }
5978
5979 fn make_nor_binding() -> Vec<MonotonicFoldBinding> {
5980 vec![MonotonicFoldBinding {
5981 fold_name: "prob".into(),
5982 aggregate: std::sync::Arc::new(uni_plugin_builtin::locy_aggregates::MnorAgg),
5983 input_col_index: 1,
5984 input_col_name: None,
5985 }]
5986 }
5987
5988 fn make_prod_binding() -> Vec<MonotonicFoldBinding> {
5989 vec![MonotonicFoldBinding {
5990 fold_name: "prob".into(),
5991 aggregate: std::sync::Arc::new(uni_plugin_builtin::locy_aggregates::MprodAgg),
5992 input_col_index: 1,
5993 input_col_name: None,
5994 }]
5995 }
5996
5997 fn acc_key(name: &str) -> (Vec<ScalarKey>, String) {
5998 (vec![ScalarKey::Utf8(name.to_string())], "prob".to_string())
5999 }
6000
    /// The first noisy-OR contribution is taken as-is.
    #[test]
    fn test_monotonic_nor_first_update() {
        let mut agg = MonotonicAggState::new(make_nor_binding());
        let batch = make_f64_batch(&["a"], &[0.3]);
        let changed = agg
            .update(&[0], &[batch], false, SemiringKind::AddMultProb)
            .unwrap();
        assert!(changed);
        let val = agg.get_accumulator(&acc_key("a")).unwrap();
        assert!((val - 0.3).abs() < 1e-10, "expected 0.3, got {}", val);
    }

    /// Noisy-OR combination: 1 - (1 - 0.3)(1 - 0.5) = 0.65.
    #[test]
    fn test_monotonic_nor_two_updates() {
        let mut agg = MonotonicAggState::new(make_nor_binding());
        let batch1 = make_f64_batch(&["a"], &[0.3]);
        agg.update(&[0], &[batch1], false, SemiringKind::AddMultProb)
            .unwrap();
        let batch2 = make_f64_batch(&["a"], &[0.5]);
        agg.update(&[0], &[batch2], false, SemiringKind::AddMultProb)
            .unwrap();
        let val = agg.get_accumulator(&acc_key("a")).unwrap();
        assert!((val - 0.65).abs() < 1e-10, "expected 0.65, got {}", val);
    }

    /// The first product contribution is taken as-is.
    #[test]
    fn test_monotonic_prod_first_update() {
        let mut agg = MonotonicAggState::new(make_prod_binding());
        let batch = make_f64_batch(&["a"], &[0.6]);
        let changed = agg
            .update(&[0], &[batch], false, SemiringKind::AddMultProb)
            .unwrap();
        assert!(changed);
        let val = agg.get_accumulator(&acc_key("a")).unwrap();
        assert!((val - 0.6).abs() < 1e-10, "expected 0.6, got {}", val);
    }

    /// Product combination: 0.6 * 0.8 = 0.48.
    #[test]
    fn test_monotonic_prod_two_updates() {
        let mut agg = MonotonicAggState::new(make_prod_binding());
        let batch1 = make_f64_batch(&["a"], &[0.6]);
        agg.update(&[0], &[batch1], false, SemiringKind::AddMultProb)
            .unwrap();
        let batch2 = make_f64_batch(&["a"], &[0.8]);
        agg.update(&[0], &[batch2], false, SemiringKind::AddMultProb)
            .unwrap();
        let val = agg.get_accumulator(&acc_key("a")).unwrap();
        assert!((val - 0.48).abs() < 1e-10, "expected 0.48, got {}", val);
    }

    /// After a snapshot, an empty noisy-OR update changes nothing and the state
    /// is stable.
    #[test]
    fn test_monotonic_nor_stability() {
        let mut agg = MonotonicAggState::new(make_nor_binding());
        let batch = make_f64_batch(&["a"], &[0.3]);
        agg.update(&[0], &[batch], false, SemiringKind::AddMultProb)
            .unwrap();
        agg.snapshot();
        let changed = agg
            .update(&[0], &[], false, SemiringKind::AddMultProb)
            .unwrap();
        assert!(!changed);
        assert!(agg.is_stable());
    }

    /// After a snapshot, an empty product update changes nothing and the state
    /// is stable.
    #[test]
    fn test_monotonic_prod_stability() {
        let mut agg = MonotonicAggState::new(make_prod_binding());
        let batch = make_f64_batch(&["a"], &[0.6]);
        agg.update(&[0], &[batch], false, SemiringKind::AddMultProb)
            .unwrap();
        agg.snapshot();
        let changed = agg
            .update(&[0], &[], false, SemiringKind::AddMultProb)
            .unwrap();
        assert!(!changed);
        assert!(agg.is_stable());
    }

    /// Groups accumulate independently: a = 1-(0.7)(0.5) = 0.65 and
    /// b = 1-(0.5)(0.8) = 0.6.
    #[test]
    fn test_monotonic_nor_multi_group() {
        let mut agg = MonotonicAggState::new(make_nor_binding());
        let batch1 = make_f64_batch(&["a", "b"], &[0.3, 0.5]);
        agg.update(&[0], &[batch1], false, SemiringKind::AddMultProb)
            .unwrap();
        let batch2 = make_f64_batch(&["a", "b"], &[0.5, 0.2]);
        agg.update(&[0], &[batch2], false, SemiringKind::AddMultProb)
            .unwrap();

        let val_a = agg.get_accumulator(&acc_key("a")).unwrap();
        let val_b = agg.get_accumulator(&acc_key("b")).unwrap();
        assert!(
            (val_a - 0.65).abs() < 1e-10,
            "expected a=0.65, got {}",
            val_a
        );
        assert!((val_b - 0.6).abs() < 1e-10, "expected b=0.6, got {}", val_b);
    }

    /// Zero absorbs under product. Once the accumulator hits 0.0, further
    /// contributions leave it unchanged and the state is stable.
    #[test]
    fn test_monotonic_prod_zero_absorbing() {
        let mut agg = MonotonicAggState::new(make_prod_binding());
        let batch1 = make_f64_batch(&["a"], &[0.5]);
        agg.update(&[0], &[batch1], false, SemiringKind::AddMultProb)
            .unwrap();
        let batch2 = make_f64_batch(&["a"], &[0.0]);
        agg.update(&[0], &[batch2], false, SemiringKind::AddMultProb)
            .unwrap();

        let val = agg.get_accumulator(&acc_key("a")).unwrap();
        assert!((val - 0.0).abs() < 1e-10, "expected 0.0, got {}", val);

        agg.snapshot();
        // 0.0 * 0.5 is still 0.0, so this update must not count as a change.
        let batch3 = make_f64_batch(&["a"], &[0.5]);
        let changed = agg
            .update(&[0], &[batch3], false, SemiringKind::AddMultProb)
            .unwrap();
        assert!(!changed);
        assert!(agg.is_stable());
    }
6125
    /// In non-strict mode, an out-of-range input (1.5) is clamped to 1.0.
    #[test]
    fn test_monotonic_nor_clamping() {
        let mut agg = MonotonicAggState::new(make_nor_binding());
        let batch = make_f64_batch(&["a"], &[1.5]);
        agg.update(&[0], &[batch], false, SemiringKind::AddMultProb)
            .unwrap();
        let val = agg.get_accumulator(&acc_key("a")).unwrap();
        assert!((val - 1.0).abs() < 1e-10, "expected 1.0, got {}", val);
    }

    /// One is absorbing for noisy-OR: after a 1.0 contribution the result is 1.0.
    #[test]
    fn test_monotonic_nor_absorbing() {
        let mut agg = MonotonicAggState::new(make_nor_binding());
        let batch1 = make_f64_batch(&["a"], &[0.3]);
        agg.update(&[0], &[batch1], false, SemiringKind::AddMultProb)
            .unwrap();
        let batch2 = make_f64_batch(&["a"], &[1.0]);
        agg.update(&[0], &[batch2], false, SemiringKind::AddMultProb)
            .unwrap();
        let val = agg.get_accumulator(&acc_key("a")).unwrap();
        assert!((val - 1.0).abs() < 1e-10, "expected 1.0, got {}", val);
    }

    /// In strict mode (third `update` argument = `true`), a noisy-OR input
    /// outside [0, 1] is rejected with an error naming
    /// `strict_probability_domain`.
    #[test]
    fn test_monotonic_agg_strict_nor_rejects() {
        let mut agg = MonotonicAggState::new(make_nor_binding());
        let batch = make_f64_batch(&["a"], &[1.5]);
        let result = agg.update(&[0], &[batch], true, SemiringKind::AddMultProb);
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(
            err.contains("strict_probability_domain"),
            "Expected strict error, got: {}",
            err
        );
    }

    /// In strict mode, a product input outside [0, 1] is rejected as well.
    #[test]
    fn test_monotonic_agg_strict_prod_rejects() {
        let mut agg = MonotonicAggState::new(make_prod_binding());
        let batch = make_f64_batch(&["a"], &[2.0]);
        let result = agg.update(&[0], &[batch], true, SemiringKind::AddMultProb);
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(
            err.contains("strict_probability_domain"),
            "Expected strict error, got: {}",
            err
        );
    }

    /// Strict mode accepts in-range inputs and accumulates them normally.
    #[test]
    fn test_monotonic_agg_strict_accepts_valid() {
        let mut agg = MonotonicAggState::new(make_nor_binding());
        let batch = make_f64_batch(&["a"], &[0.5]);
        let result = agg.update(&[0], &[batch], true, SemiringKind::AddMultProb);
        assert!(result.is_ok());
        let val = agg.get_accumulator(&acc_key("a")).unwrap();
        assert!((val - 0.5).abs() < 1e-10, "expected 0.5, got {}", val);
    }
6190
6191 fn make_vid_prob_batch(vids: &[u64], probs: &[f64]) -> RecordBatch {
6194 use arrow_array::UInt64Array;
6195 let schema = Arc::new(Schema::new(vec![
6196 Field::new("vid", DataType::UInt64, true),
6197 Field::new("prob", DataType::Float64, true),
6198 ]));
6199 RecordBatch::try_new(
6200 schema,
6201 vec![
6202 Arc::new(UInt64Array::from(vids.to_vec())),
6203 Arc::new(Float64Array::from(probs.to_vec())),
6204 ],
6205 )
6206 .unwrap()
6207 }
6208
6209 #[test]
6210 fn test_prob_complement_basic() {
6211 let body = make_vid_prob_batch(&[1, 2], &[0.9, 0.8]);
6213 let neg = make_vid_prob_batch(&[1], &[0.7]);
6214 let join_cols = vec![("vid".to_string(), "vid".to_string())];
6215 let result = apply_prob_complement_composite(
6216 vec![body],
6217 &[neg],
6218 &join_cols,
6219 "prob",
6220 "__complement_0",
6221 )
6222 .unwrap();
6223 assert_eq!(result.len(), 1);
6224 let batch = &result[0];
6225 let complement = batch
6226 .column_by_name("__complement_0")
6227 .unwrap()
6228 .as_any()
6229 .downcast_ref::<Float64Array>()
6230 .unwrap();
6231 assert!(
6233 (complement.value(0) - 0.3).abs() < 1e-10,
6234 "expected 0.3, got {}",
6235 complement.value(0)
6236 );
6237 assert!(
6239 (complement.value(1) - 1.0).abs() < 1e-10,
6240 "expected 1.0, got {}",
6241 complement.value(1)
6242 );
6243 }
6244
#[test]
fn test_prob_complement_noisy_or_duplicates() {
    // Two negated facts share vid 1 (p=0.3 and p=0.5). They are expected to
    // combine into one complement for the single body row:
    // (1 - 0.3) * (1 - 0.5) = 0.35.
    let body = make_vid_prob_batch(&[1], &[0.9]);
    let neg = make_vid_prob_batch(&[1, 1], &[0.3, 0.5]);
    let join_cols = vec![("vid".to_string(), "vid".to_string())];
    let result = apply_prob_complement_composite(
        vec![body],
        &[neg],
        &join_cols,
        "prob",
        "__complement_0",
    )
    .unwrap();
    // Pin the output shape too: duplicate negated keys must be aggregated,
    // not fan the single body row out into one row per negated match.
    assert_eq!(result.len(), 1);
    let batch = &result[0];
    assert_eq!(batch.num_rows(), 1);
    let complement = batch
        .column_by_name("__complement_0")
        .unwrap()
        .as_any()
        .downcast_ref::<Float64Array>()
        .unwrap();
    assert!(
        (complement.value(0) - 0.35).abs() < 1e-10,
        "expected 0.35, got {}",
        complement.value(0)
    );
}
6274
#[test]
fn test_prob_complement_empty_neg() {
    // With no negated facts, every body row's complement is 1.0, and the
    // body rows come back one-for-one in a single batch.
    let body = make_vid_prob_batch(&[1, 2], &[0.5, 0.6]);
    let join_cols = vec![("vid".to_string(), "vid".to_string())];
    let result =
        apply_prob_complement_composite(vec![body], &[], &join_cols, "prob", "__complement_0")
            .unwrap();
    // Pin the output shape, not just the first two complement values.
    assert_eq!(result.len(), 1);
    let batch = &result[0];
    assert_eq!(batch.num_rows(), 2);
    let complement = batch
        .column_by_name("__complement_0")
        .unwrap()
        .as_any()
        .downcast_ref::<Float64Array>()
        .unwrap();
    for i in 0..2 {
        assert!(
            (complement.value(i) - 1.0).abs() < 1e-10,
            "row {}: expected 1.0, got {}",
            i,
            complement.value(i)
        );
    }
}
6299
#[test]
fn test_anti_join_basic() {
    use arrow_array::UInt64Array;
    // Body vids are {1, 2, 3} and vid 2 is negated, so vids 1 and 3 must
    // survive, in that order, in a single batch.
    let join_cols = vec![("vid".to_string(), "vid".to_string())];
    let result = apply_anti_join_composite(
        vec![make_vid_prob_batch(&[1, 2, 3], &[0.5, 0.6, 0.7])],
        &[make_vid_prob_batch(&[2], &[0.0])],
        &join_cols,
    )
    .unwrap();
    assert_eq!(result.len(), 1);

    let survivors = &result[0];
    assert_eq!(survivors.num_rows(), 2);
    let vid_col = survivors.column_by_name("vid").unwrap();
    let vids = vid_col.as_any().downcast_ref::<UInt64Array>().unwrap();
    assert_eq!(vids.value(0), 1);
    assert_eq!(vids.value(1), 3);
}
6320
#[test]
fn test_anti_join_empty_neg() {
    // Nothing to exclude: all three body rows come back in a single batch.
    let join_cols = vec![("vid".to_string(), "vid".to_string())];
    let body = make_vid_prob_batch(&[1, 2, 3], &[0.5, 0.6, 0.7]);
    let kept = apply_anti_join_composite(vec![body], &[], &join_cols).unwrap();
    assert_eq!(kept.len(), 1);
    let rows = kept[0].num_rows();
    assert_eq!(rows, 3);
}
6330
#[test]
fn test_anti_join_all_excluded() {
    // Every body vid also appears on the negated side, so no rows may remain.
    // Rows are totalled across batches because this test does not pin how
    // many (possibly empty) batches are returned.
    let join_cols = vec![("vid".to_string(), "vid".to_string())];
    let body = make_vid_prob_batch(&[1, 2], &[0.5, 0.6]);
    let neg = make_vid_prob_batch(&[1, 2], &[0.0, 0.0]);
    let remaining = apply_anti_join_composite(vec![body], &[neg], &join_cols)
        .unwrap()
        .iter()
        .fold(0usize, |acc, b| acc + b.num_rows());
    assert_eq!(remaining, 0);
}
6341
#[test]
fn test_multiply_prob_single_complement() {
    // Start from vid 1 with prob 0.8 and append a `__complement_0` factor of 0.5.
    // The factor must be folded into `prob` (0.8 * 0.5 = 0.4), and the helper
    // column must be dropped from the output.
    let body = make_vid_prob_batch(&[1], &[0.8]);
    let schema = {
        let mut fields: Vec<Arc<Field>> = body.schema().fields().iter().cloned().collect();
        fields.push(Arc::new(Field::new("__complement_0", DataType::Float64, true)));
        Arc::new(Schema::new(fields))
    };
    let factor: arrow_array::ArrayRef = Arc::new(Float64Array::from(vec![0.5]));
    let columns: Vec<arrow_array::ArrayRef> = body
        .columns()
        .iter()
        .cloned()
        .chain(std::iter::once(factor))
        .collect();
    let batch = RecordBatch::try_new(schema, columns).unwrap();

    let factors = ["__complement_0".to_string()];
    let result = multiply_prob_factors(vec![batch], Some("prob"), &factors).unwrap();
    assert_eq!(result.len(), 1);

    let merged = &result[0];
    assert!(merged.column_by_name("__complement_0").is_none());
    let prob_col = merged.column_by_name("prob").unwrap();
    let prob = prob_col.as_any().downcast_ref::<Float64Array>().unwrap();
    let p = prob.value(0);
    assert!((p - 0.4).abs() < 1e-10, "expected 0.4, got {}", p);
}
6378
#[test]
fn test_multiply_prob_multiple_complements() {
    // Start from vid 1 with prob 0.8 and add two factor columns (__c1 = 0.5, __c2 = 0.6).
    // The product 0.8 * 0.5 * 0.6 = 0.24 must end up in `prob`, and both
    // factor columns must be removed.
    let body = make_vid_prob_batch(&[1], &[0.8]);
    let c1 = Float64Array::from(vec![0.5]);
    let c2 = Float64Array::from(vec![0.6]);
    let mut cols: Vec<arrow_array::ArrayRef> = body.columns().to_vec();
    cols.push(Arc::new(c1));
    cols.push(Arc::new(c2));
    let mut fields: Vec<Arc<Field>> = body.schema().fields().iter().cloned().collect();
    fields.push(Arc::new(Field::new("__c1", DataType::Float64, true)));
    fields.push(Arc::new(Field::new("__c2", DataType::Float64, true)));
    let schema = Arc::new(Schema::new(fields));
    let batch = RecordBatch::try_new(schema, cols).unwrap();

    let result = multiply_prob_factors(
        vec![batch],
        Some("prob"),
        &["__c1".to_string(), "__c2".to_string()],
    )
    .unwrap();
    // Same shape guarantee as the single-complement case: one batch in, one batch out.
    assert_eq!(result.len(), 1);
    let out = &result[0];
    assert!(out.column_by_name("__c1").is_none());
    assert!(out.column_by_name("__c2").is_none());
    let prob = out
        .column_by_name("prob")
        .unwrap()
        .as_any()
        .downcast_ref::<Float64Array>()
        .unwrap();
    assert!(
        (prob.value(0) - 0.24).abs() < 1e-10,
        "expected 0.24, got {}",
        prob.value(0)
    );
}
6415
#[test]
fn test_multiply_prob_no_prob_column() {
    use arrow_array::UInt64Array;
    // The batch has no `prob` column, and None is passed as the target column.
    // The factor column `__c1` must still be removed. Only the `vid` column and its
    // single row should remain.
    let schema = Arc::new(Schema::new(vec![
        Field::new("vid", DataType::UInt64, true),
        Field::new("__c1", DataType::Float64, true),
    ]));
    let batch = RecordBatch::try_new(
        schema,
        vec![
            Arc::new(UInt64Array::from(vec![1u64])),
            Arc::new(Float64Array::from(vec![0.7])),
        ],
    )
    .unwrap();

    let result = multiply_prob_factors(vec![batch], None, &["__c1".to_string()]).unwrap();
    assert_eq!(result.len(), 1);
    let out = &result[0];
    assert!(out.column_by_name("__c1").is_none());
    // A column count of 1 alone would not prove that the survivor is `vid`.
    assert_eq!(out.num_columns(), 1);
    assert!(out.column_by_name("vid").is_some());
    assert_eq!(out.num_rows(), 1);
}
6440}