1use crate::query::df_graph::GraphExecutionContext;
10use crate::query::df_graph::common::{
11 ScalarKey, arrow_err, collect_all_partitions, compute_plan_properties, execute_subplan,
12 extract_scalar_key,
13};
14use crate::query::df_graph::locy_best_by::{BestByExec, SortCriterion};
15use crate::query::df_graph::locy_errors::LocyRuntimeError;
16use crate::query::df_graph::locy_explain::{
17 ProofTerm, ProvenanceAnnotation, ProvenanceStore, compute_proof_probability,
18};
19use crate::query::df_graph::locy_fold::{FoldBinding, FoldExec};
20use crate::query::df_graph::locy_priority::PriorityExec;
21use crate::query::df_graph::locy_program::interruption;
22use crate::query::planner::LogicalPlan;
23use arrow_array::RecordBatch;
24use arrow_row::{RowConverter, SortField};
25use arrow_schema::SchemaRef;
26use datafusion::common::JoinType;
27use datafusion::common::Result as DFResult;
28use datafusion::execution::{RecordBatchStream, SendableRecordBatchStream, TaskContext};
29use datafusion::physical_plan::joins::{HashJoinExec, PartitionMode};
30use datafusion::physical_plan::memory::MemoryStream;
31use datafusion::physical_plan::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet};
32use datafusion::physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties};
33use futures::Stream;
34use parking_lot::RwLock;
35use std::any::Any;
36use std::collections::{HashMap, HashSet};
37use std::fmt;
38use std::pin::Pin;
39use std::sync::{Arc, RwLock as StdRwLock};
40use std::task::{Context, Poll};
41use std::time::{Duration, Instant};
42use uni_common::Value;
43use uni_common::core::schema::Schema as UniSchema;
44use uni_cypher::ast::Expr;
45use uni_locy::{
46 ClassifierRegistry, ModelInvocation, ModelInvocationCache, RuntimeWarning, RuntimeWarningCode,
47 SemiringKind,
48};
49use uni_store::storage::manager::StorageManager;
50
/// One derived-relation scan slot whose batch contents are replaced between
/// fixpoint iterations via [`DerivedScanRegistry::write_data`].
#[derive(Debug)]
pub struct DerivedScanEntry {
    /// Identifier used by `DerivedScanRegistry::get` / `write_data` to find this slot.
    pub scan_index: usize,
    /// Name of the rule this scan belongs to (matched by `entries_for_rule`).
    pub rule_name: String,
    /// Whether the scan refers back to the rule currently being defined.
    /// NOTE(review): semantics inferred from the name — confirm against
    /// `update_derived_scan_handles`.
    pub is_self_ref: bool,
    /// Shared batch buffer; overwritten wholesale by `write_data` and read by
    /// consumers such as negation handling in the fixpoint loop.
    pub data: Arc<RwLock<Vec<RecordBatch>>>,
    /// Schema of the batches stored in `data` (not enforced by this type).
    pub schema: SchemaRef,
}
73
/// Collection of [`DerivedScanEntry`] slots addressed by `scan_index`.
#[derive(Debug, Default)]
pub struct DerivedScanRegistry {
    /// Registered slots; lookups by `scan_index` are linear scans over this list.
    entries: Vec<DerivedScanEntry>,
}
84
85impl DerivedScanRegistry {
86 pub fn new() -> Self {
88 Self::default()
89 }
90
91 pub fn add(&mut self, entry: DerivedScanEntry) {
93 self.entries.push(entry);
94 }
95
96 pub fn get(&self, scan_index: usize) -> Option<&DerivedScanEntry> {
98 self.entries.iter().find(|e| e.scan_index == scan_index)
99 }
100
101 pub fn write_data(&self, scan_index: usize, batches: Vec<RecordBatch>) {
103 if let Some(entry) = self.get(scan_index) {
104 let mut guard = entry.data.write();
105 *guard = batches;
106 }
107 }
108
109 pub fn entries_for_rule(&self, rule_name: &str) -> Vec<&DerivedScanEntry> {
111 self.entries
112 .iter()
113 .filter(|e| e.rule_name == rule_name)
114 .collect()
115 }
116}
117
/// A single FOLD output tracked incrementally across fixpoint iterations.
#[derive(Debug, Clone)]
pub struct MonotonicFoldBinding {
    /// Fold output name; forms part of the accumulator key together with the group key.
    pub fold_name: String,
    /// Aggregate implementation supplying the initial accumulator and the per-row update step.
    pub aggregate: std::sync::Arc<dyn uni_plugin::traits::locy::LocyAggregate>,
    /// Positional input column, used when `input_col_name` is `None` or not found in the batch.
    pub input_col_index: usize,
    /// Input column name, resolved against each batch's schema before falling back to the index.
    pub input_col_name: Option<String>,
}
136
/// Running FOLD accumulators kept during fixpoint iteration, used to detect when
/// monotonic aggregates have stopped changing.
#[derive(Debug)]
pub struct MonotonicAggState {
    /// Current accumulator value per (group key, fold output name).
    accumulators: HashMap<(Vec<ScalarKey>, String), f64>,
    /// Copy of `accumulators` taken by `snapshot`; `is_stable` compares against it.
    prev_snapshot: HashMap<(Vec<ScalarKey>, String), f64>,
    /// Fold bindings evaluated for every input row.
    bindings: Vec<MonotonicFoldBinding>,
}
151
152impl MonotonicAggState {
153 pub fn new(bindings: Vec<MonotonicFoldBinding>) -> Self {
155 Self {
156 accumulators: HashMap::new(),
157 prev_snapshot: HashMap::new(),
158 bindings,
159 }
160 }
161
162 pub fn update(
188 &mut self,
189 key_indices: &[usize],
190 delta_batches: &[RecordBatch],
191 strict: bool,
192 semiring_kind: SemiringKind,
193 ) -> DFResult<bool> {
194 let mut changed = false;
195 for batch in delta_batches {
196 for row_idx in 0..batch.num_rows() {
197 let group_key = extract_scalar_key(batch, key_indices, row_idx);
198 for binding in &self.bindings {
199 let idx = binding
200 .input_col_name
201 .as_ref()
202 .and_then(|name| batch.schema().index_of(name).ok())
203 .unwrap_or(binding.input_col_index);
204 if idx >= batch.num_columns() {
205 continue;
206 }
207 let col = batch.column(idx);
208 let val = extract_f64(col.as_ref(), row_idx);
209 if let Some(val) = val {
210 let map_key = (group_key.clone(), binding.fold_name.clone());
211 let initial = binding.aggregate.initial_accum_f64().unwrap_or(0.0);
212 let entry = self.accumulators.entry(map_key).or_insert(initial);
213 let old = *entry;
214 if matches!(semiring_kind, SemiringKind::MaxMinProb)
225 && binding.aggregate.is_probability_aggregate()
226 {
227 use uni_locy::LocySemiring;
228 let sr = uni_locy::MaxMinProb;
229 let is_nor = binding.aggregate.is_noisy_or();
230 let label = if is_nor { "MNOR" } else { "MPROD" };
231 if strict && !(0.0..=1.0).contains(&val) {
232 return Err(datafusion::error::DataFusionError::Execution(
233 format!(
234 "strict_probability_domain: {label} input {val} is outside [0, 1]"
235 ),
236 ));
237 }
238 if !strict && !(0.0..=1.0).contains(&val) {
239 tracing::warn!(
240 "{label} input {val} outside [0,1], clamped to {}",
241 val.clamp(0.0, 1.0)
242 );
243 }
244 let p = val.clamp(0.0, 1.0);
245 *entry = if is_nor {
247 sr.plus(entry, &p)
248 } else {
249 sr.times(entry, &p)
250 };
251 if (*entry - old).abs() > f64::EPSILON {
252 changed = true;
253 }
254 continue;
255 }
256 match binding.aggregate.update_step(*entry, val, strict) {
257 Ok(new_val) => {
258 *entry = new_val;
259 if (*entry - old).abs() > f64::EPSILON {
260 changed = true;
261 }
262 }
263 Err(e) if e.code == uni_plugin::FnError::CODE_UNKNOWN_FUNCTION => {
264 }
268 Err(e) => {
269 return Err(datafusion::error::DataFusionError::Execution(
272 e.message,
273 ));
274 }
275 }
276 }
277 }
278 }
279 }
280 Ok(changed)
281 }
282
283 pub fn snapshot(&mut self) {
285 self.prev_snapshot = self.accumulators.clone();
286 }
287
288 pub fn is_stable(&self) -> bool {
290 if self.accumulators.len() != self.prev_snapshot.len() {
291 return false;
292 }
293 for (key, val) in &self.accumulators {
294 match self.prev_snapshot.get(key) {
295 Some(prev) if (*val - *prev).abs() <= f64::EPSILON => {}
296 _ => return false,
297 }
298 }
299 true
300 }
301
302 #[cfg(test)]
304 pub(crate) fn get_accumulator(&self, key: &(Vec<ScalarKey>, String)) -> Option<f64> {
305 self.accumulators.get(key).copied()
306 }
307}
308
309fn extract_f64(col: &dyn arrow_array::Array, row_idx: usize) -> Option<f64> {
311 if col.is_null(row_idx) {
312 return None;
313 }
314 if let Some(arr) = col.as_any().downcast_ref::<arrow_array::Float64Array>() {
315 Some(arr.value(row_idx))
316 } else {
317 col.as_any()
318 .downcast_ref::<arrow_array::Int64Array>()
319 .map(|arr| arr.value(row_idx) as f64)
320 }
321}
322
/// Exact full-row dedup based on Arrow's row format: each row is encoded with
/// `converter` and the encoded bytes are remembered in `seen`.
struct RowDedupState {
    /// Encodes every column of the fact schema into comparable row bytes.
    converter: RowConverter,
    /// Encoded bytes of every row ingested or accepted so far.
    seen: HashSet<Box<[u8]>>,
}
336
337impl RowDedupState {
338 fn try_new(schema: &SchemaRef) -> Option<Self> {
343 let fields: Vec<SortField> = schema
344 .fields()
345 .iter()
346 .map(|f| SortField::new(f.data_type().clone()))
347 .collect();
348 match RowConverter::new(fields) {
349 Ok(converter) => Some(Self {
350 converter,
351 seen: HashSet::new(),
352 }),
353 Err(e) => {
354 tracing::warn!(
355 "RowDedupState: RowConverter unsupported for schema, falling back to legacy dedup: {}",
356 e
357 );
358 None
359 }
360 }
361 }
362
363 fn ingest_existing(&mut self, facts: &[RecordBatch], _schema: &SchemaRef) {
368 self.seen.clear();
369 for batch in facts {
370 if batch.num_rows() == 0 {
371 continue;
372 }
373 let arrays: Vec<_> = batch.columns().to_vec();
374 if let Ok(rows) = self.converter.convert_columns(&arrays) {
375 for row_idx in 0..batch.num_rows() {
376 let row_bytes: Box<[u8]> = rows.row(row_idx).data().into();
377 self.seen.insert(row_bytes);
378 }
379 }
380 }
381 }
382
383 fn compute_delta(
389 &mut self,
390 candidates: &[RecordBatch],
391 schema: &SchemaRef,
392 ) -> DFResult<Vec<RecordBatch>> {
393 let mut delta_batches = Vec::new();
394 for batch in candidates {
395 if batch.num_rows() == 0 {
396 continue;
397 }
398
399 let arrays: Vec<_> = batch.columns().to_vec();
401 let rows = self.converter.convert_columns(&arrays).map_err(arrow_err)?;
402
403 let mut keep = Vec::with_capacity(batch.num_rows());
405 for row_idx in 0..batch.num_rows() {
406 let row_bytes: Box<[u8]> = rows.row(row_idx).data().into();
407 keep.push(self.seen.insert(row_bytes));
408 }
409
410 let keep_mask = arrow_array::BooleanArray::from(keep);
411 let new_cols = batch
412 .columns()
413 .iter()
414 .map(|col| {
415 arrow::compute::filter(col.as_ref(), &keep_mask).map_err(|e| {
416 datafusion::error::DataFusionError::ArrowError(Box::new(e), None)
417 })
418 })
419 .collect::<DFResult<Vec<_>>>()?;
420
421 if new_cols.first().is_some_and(|c| !c.is_empty()) {
422 let filtered = RecordBatch::try_new(Arc::clone(schema), new_cols).map_err(|e| {
423 datafusion::error::DataFusionError::ArrowError(Box::new(e), None)
424 })?;
425 delta_batches.push(filtered);
426 }
427 }
428 Ok(delta_batches)
429 }
430}
431
/// Per-rule state for fixpoint evaluation: the facts accepted so far, the delta accepted
/// by the most recent merge, and the bookkeeping used for dedup, the memory limit and
/// monotonic-aggregate convergence.
pub struct FixpointState {
    /// Rule name, used in tracing output and error messages.
    rule_name: String,
    /// All facts accepted so far.
    facts: Vec<RecordBatch>,
    /// Facts newly accepted by the most recent merge (cleared when nothing changed).
    delta: Vec<RecordBatch>,
    /// Schema used when rebuilding batches; may be replaced by `reconcile_schema`.
    schema: SchemaRef,
    /// Positions of the KEY columns within `schema`.
    key_column_indices: Vec<usize>,
    /// Names of the KEY columns, used to re-resolve `key_column_indices` after a schema change.
    key_column_names: Vec<String>,
    /// `0..schema.fields().len()`; the column set used by the legacy full-row dedup.
    all_column_indices: Vec<usize>,
    /// Running byte size of `facts`, as measured by `batch_byte_size`.
    facts_bytes: usize,
    /// Upper bound on `facts_bytes`; exceeding it yields `MemoryLimitExceeded`.
    max_derived_bytes: usize,
    /// Optional monotonic FOLD accumulators, updated on every accepted delta.
    monotonic_agg: Option<MonotonicAggState>,
    /// Row-format dedup state; `None` when `RowConverter` cannot encode the schema.
    row_dedup: Option<RowDedupState>,
    /// Forwarded to `MonotonicAggState::update` as its `strict` flag.
    strict_probability_domain: bool,
    /// Semiring forwarded to `MonotonicAggState::update`.
    semiring_kind: SemiringKind,
}
464
465impl FixpointState {
466 pub fn new(
471 rule_name: String,
472 schema: SchemaRef,
473 key_column_indices: Vec<usize>,
474 max_derived_bytes: usize,
475 monotonic_agg: Option<MonotonicAggState>,
476 strict_probability_domain: bool,
477 ) -> Self {
478 Self::new_with_semiring(
479 rule_name,
480 schema,
481 key_column_indices,
482 max_derived_bytes,
483 monotonic_agg,
484 strict_probability_domain,
485 SemiringKind::AddMultProb,
486 )
487 }
488
489 pub fn new_with_semiring(
490 rule_name: String,
491 schema: SchemaRef,
492 key_column_indices: Vec<usize>,
493 max_derived_bytes: usize,
494 monotonic_agg: Option<MonotonicAggState>,
495 strict_probability_domain: bool,
496 semiring_kind: SemiringKind,
497 ) -> Self {
498 let num_cols = schema.fields().len();
499 let row_dedup = RowDedupState::try_new(&schema);
500 let key_column_names: Vec<String> = key_column_indices
501 .iter()
502 .filter_map(|&i| schema.fields().get(i).map(|f| f.name().clone()))
503 .collect();
504 Self {
505 rule_name,
506 facts: Vec::new(),
507 delta: Vec::new(),
508 schema,
509 key_column_indices,
510 key_column_names,
511 all_column_indices: (0..num_cols).collect(),
512 facts_bytes: 0,
513 max_derived_bytes,
514 monotonic_agg,
515 row_dedup,
516 strict_probability_domain,
517 semiring_kind,
518 }
519 }
520
521 fn reconcile_schema(&mut self, actual_schema: &SchemaRef) {
528 if self.schema.fields() != actual_schema.fields() {
529 tracing::debug!(
530 rule = %self.rule_name,
531 "Reconciling fixpoint schema from physical plan output",
532 );
533 self.schema = Arc::clone(actual_schema);
534 self.row_dedup = RowDedupState::try_new(&self.schema);
535 let new_indices: Vec<usize> = self
539 .key_column_names
540 .iter()
541 .filter_map(|name| actual_schema.index_of(name).ok())
542 .collect();
543 if new_indices.len() == self.key_column_names.len() {
544 self.key_column_indices = new_indices;
545 }
546 let num_cols = actual_schema.fields().len();
548 self.all_column_indices = (0..num_cols).collect();
549 }
550 }
551
552 pub async fn merge_delta(
556 &mut self,
557 candidates: Vec<RecordBatch>,
558 task_ctx: Option<Arc<TaskContext>>,
559 ) -> DFResult<bool> {
560 if candidates.is_empty() || candidates.iter().all(|b| b.num_rows() == 0) {
561 self.delta.clear();
562 return Ok(false);
563 }
564
565 if let Some(first) = candidates.iter().find(|b| b.num_rows() > 0) {
569 self.reconcile_schema(&first.schema());
570 }
571
572 let candidates = round_float_columns(&candidates);
574
575 let delta = self.compute_delta(&candidates, task_ctx.as_ref()).await?;
577
578 if delta.is_empty() || delta.iter().all(|b| b.num_rows() == 0) {
579 self.delta.clear();
580 if let Some(ref mut agg) = self.monotonic_agg {
582 agg.snapshot();
583 }
584 return Ok(false);
585 }
586
587 let delta_bytes: usize = delta.iter().map(batch_byte_size).sum();
589 if self.facts_bytes + delta_bytes > self.max_derived_bytes {
590 return Err(datafusion::error::DataFusionError::Execution(
591 LocyRuntimeError::MemoryLimitExceeded {
592 rule: self.rule_name.clone(),
593 bytes: self.facts_bytes + delta_bytes,
594 limit: self.max_derived_bytes,
595 }
596 .to_string(),
597 ));
598 }
599
600 if let Some(ref mut agg) = self.monotonic_agg {
602 agg.snapshot();
603 agg.update(
604 &self.key_column_indices,
605 &delta,
606 self.strict_probability_domain,
607 self.semiring_kind,
608 )?;
609 }
610
611 self.facts_bytes += delta_bytes;
613 self.facts.extend(delta.iter().cloned());
614 self.delta = delta;
615
616 Ok(true)
617 }
618
619 async fn compute_delta(
626 &mut self,
627 candidates: &[RecordBatch],
628 task_ctx: Option<&Arc<TaskContext>>,
629 ) -> DFResult<Vec<RecordBatch>> {
630 let total_existing: usize = self.facts.iter().map(|b| b.num_rows()).sum();
631 if total_existing >= DEDUP_ANTI_JOIN_THRESHOLD
632 && let Some(ctx) = task_ctx
633 {
634 return arrow_left_anti_dedup(candidates.to_vec(), &self.facts, &self.schema, ctx)
635 .await;
636 }
637 if let Some(ref mut rd) = self.row_dedup {
638 rd.compute_delta(candidates, &self.schema)
639 } else {
640 self.compute_delta_legacy(candidates)
641 }
642 }
643
644 fn compute_delta_legacy(&self, candidates: &[RecordBatch]) -> DFResult<Vec<RecordBatch>> {
648 let mut existing: HashSet<Vec<ScalarKey>> = HashSet::new();
650 for batch in &self.facts {
651 for row_idx in 0..batch.num_rows() {
652 let key = extract_scalar_key(batch, &self.all_column_indices, row_idx);
653 existing.insert(key);
654 }
655 }
656
657 let mut delta_batches = Vec::new();
658 for batch in candidates {
659 if batch.num_rows() == 0 {
660 continue;
661 }
662 let mut keep = Vec::with_capacity(batch.num_rows());
664 for row_idx in 0..batch.num_rows() {
665 let key = extract_scalar_key(batch, &self.all_column_indices, row_idx);
666 keep.push(!existing.contains(&key));
667 }
668
669 for (row_idx, kept) in keep.iter_mut().enumerate() {
671 if *kept {
672 let key = extract_scalar_key(batch, &self.all_column_indices, row_idx);
673 if !existing.insert(key) {
674 *kept = false;
675 }
676 }
677 }
678
679 let keep_mask = arrow_array::BooleanArray::from(keep);
680 let new_rows = batch
681 .columns()
682 .iter()
683 .map(|col| {
684 arrow::compute::filter(col.as_ref(), &keep_mask).map_err(|e| {
685 datafusion::error::DataFusionError::ArrowError(Box::new(e), None)
686 })
687 })
688 .collect::<DFResult<Vec<_>>>()?;
689
690 if new_rows.first().is_some_and(|c| !c.is_empty()) {
691 let filtered =
692 RecordBatch::try_new(Arc::clone(&self.schema), new_rows).map_err(|e| {
693 datafusion::error::DataFusionError::ArrowError(Box::new(e), None)
694 })?;
695 delta_batches.push(filtered);
696 }
697 }
698
699 Ok(delta_batches)
700 }
701
702 pub fn is_converged(&self) -> bool {
704 let delta_empty = self.delta.is_empty() || self.delta.iter().all(|b| b.num_rows() == 0);
705 let agg_stable = self.monotonic_agg.as_ref().is_none_or(|a| a.is_stable());
706 delta_empty && agg_stable
707 }
708
709 pub fn all_facts(&self) -> &[RecordBatch] {
711 &self.facts
712 }
713
714 pub fn all_delta(&self) -> &[RecordBatch] {
716 &self.delta
717 }
718
719 pub fn into_facts(self) -> Vec<RecordBatch> {
721 self.facts
722 }
723
724 pub fn merge_best_by(
735 &mut self,
736 candidates: Vec<RecordBatch>,
737 sort_criteria: &[SortCriterion],
738 ) -> DFResult<bool> {
739 if candidates.is_empty() || candidates.iter().all(|b| b.num_rows() == 0) {
740 self.delta.clear();
741 return Ok(false);
742 }
743
744 if let Some(first) = candidates.iter().find(|b| b.num_rows() > 0) {
746 self.reconcile_schema(&first.schema());
747 }
748
749 let candidates = round_float_columns(&candidates);
751
752 let old_best: HashMap<Vec<ScalarKey>, Vec<ScalarKey>> =
754 self.build_key_criteria_map(sort_criteria);
755
756 let mut all_batches = self.facts.clone();
758 all_batches.extend(candidates);
759 let all_batches: Vec<_> = all_batches
760 .into_iter()
761 .filter(|b| b.num_rows() > 0)
762 .collect();
763 if all_batches.is_empty() {
764 self.delta.clear();
765 return Ok(false);
766 }
767
768 let combined = arrow::compute::concat_batches(&self.schema, &all_batches)
769 .map_err(|e| datafusion::error::DataFusionError::ArrowError(Box::new(e), None))?;
770
771 if combined.num_rows() == 0 {
772 self.delta.clear();
773 return Ok(false);
774 }
775
776 let mut sort_columns = Vec::new();
779 for &ki in &self.key_column_indices {
780 if ki >= combined.num_columns() {
781 continue;
782 }
783 sort_columns.push(arrow::compute::SortColumn {
784 values: Arc::clone(combined.column(ki)),
785 options: Some(arrow::compute::SortOptions {
786 descending: false,
787 nulls_first: false,
788 }),
789 });
790 }
791 for criterion in sort_criteria {
792 if criterion.col_index >= combined.num_columns() {
793 continue;
794 }
795 sort_columns.push(arrow::compute::SortColumn {
796 values: Arc::clone(combined.column(criterion.col_index)),
797 options: Some(arrow::compute::SortOptions {
798 descending: !criterion.ascending,
799 nulls_first: criterion.nulls_first,
800 }),
801 });
802 }
803
804 let sorted_indices =
805 arrow::compute::lexsort_to_indices(&sort_columns, None).map_err(arrow_err)?;
806 let sorted_columns: Vec<_> = combined
807 .columns()
808 .iter()
809 .map(|col| arrow::compute::take(col.as_ref(), &sorted_indices, None))
810 .collect::<Result<Vec<_>, _>>()
811 .map_err(|e| datafusion::error::DataFusionError::ArrowError(Box::new(e), None))?;
812 let sorted = RecordBatch::try_new(Arc::clone(&self.schema), sorted_columns)
813 .map_err(|e| datafusion::error::DataFusionError::ArrowError(Box::new(e), None))?;
814
815 let mut keep_indices: Vec<u32> = Vec::new();
817 let mut prev_key: Option<Vec<ScalarKey>> = None;
818 for row_idx in 0..sorted.num_rows() {
819 let key = extract_scalar_key(&sorted, &self.key_column_indices, row_idx);
820 let is_new_group = match &prev_key {
821 None => true,
822 Some(prev) => *prev != key,
823 };
824 if is_new_group {
825 keep_indices.push(row_idx as u32);
826 prev_key = Some(key);
827 }
828 }
829
830 let keep_array = arrow_array::UInt32Array::from(keep_indices);
831 let output_columns: Vec<_> = sorted
832 .columns()
833 .iter()
834 .map(|col| arrow::compute::take(col.as_ref(), &keep_array, None))
835 .collect::<Result<Vec<_>, _>>()
836 .map_err(|e| datafusion::error::DataFusionError::ArrowError(Box::new(e), None))?;
837 let pruned = RecordBatch::try_new(Arc::clone(&self.schema), output_columns)
838 .map_err(|e| datafusion::error::DataFusionError::ArrowError(Box::new(e), None))?;
839
840 let new_best: HashMap<Vec<ScalarKey>, Vec<ScalarKey>> = {
842 let mut map = HashMap::new();
843 for row_idx in 0..pruned.num_rows() {
844 let key = extract_scalar_key(&pruned, &self.key_column_indices, row_idx);
845 let criteria: Vec<ScalarKey> = sort_criteria
846 .iter()
847 .flat_map(|c| extract_scalar_key(&pruned, &[c.col_index], row_idx))
848 .collect();
849 map.insert(key, criteria);
850 }
851 map
852 };
853 let changed = old_best != new_best;
854
855 tracing::debug!(
856 rule = %self.rule_name,
857 old_keys = old_best.len(),
858 new_keys = new_best.len(),
859 changed = changed,
860 "BEST BY merge"
861 );
862
863 self.facts_bytes = batch_byte_size(&pruned);
865 self.facts = vec![pruned];
866 if changed {
867 self.delta = self.facts.clone();
870 } else {
871 self.delta.clear();
872 }
873
874 self.row_dedup = RowDedupState::try_new(&self.schema);
876 if let Some(ref mut rd) = self.row_dedup {
877 rd.ingest_existing(&self.facts, &self.schema);
878 }
879
880 Ok(changed)
881 }
882
883 fn build_key_criteria_map(
885 &self,
886 sort_criteria: &[SortCriterion],
887 ) -> HashMap<Vec<ScalarKey>, Vec<ScalarKey>> {
888 let mut map = HashMap::new();
889 for batch in &self.facts {
890 for row_idx in 0..batch.num_rows() {
891 let key = extract_scalar_key(batch, &self.key_column_indices, row_idx);
892 let criteria: Vec<ScalarKey> = sort_criteria
893 .iter()
894 .flat_map(|c| extract_scalar_key(batch, &[c.col_index], row_idx))
895 .collect();
896 map.insert(key, criteria);
897 }
898 }
899 map
900 }
901}
902
903fn batch_byte_size(batch: &RecordBatch) -> usize {
905 batch
906 .columns()
907 .iter()
908 .map(|col| col.get_buffer_memory_size())
909 .sum()
910}
911
912fn round_float_columns(batches: &[RecordBatch]) -> Vec<RecordBatch> {
918 batches
919 .iter()
920 .map(|batch| {
921 let schema = batch.schema();
922 let has_float = schema
923 .fields()
924 .iter()
925 .any(|f| *f.data_type() == arrow_schema::DataType::Float64);
926 if !has_float {
927 return batch.clone();
928 }
929
930 let columns: Vec<arrow_array::ArrayRef> = batch
931 .columns()
932 .iter()
933 .enumerate()
934 .map(|(i, col)| {
935 if *schema.field(i).data_type() == arrow_schema::DataType::Float64 {
936 let arr = col
937 .as_any()
938 .downcast_ref::<arrow_array::Float64Array>()
939 .unwrap();
940 let rounded: arrow_array::Float64Array = arr
941 .iter()
942 .map(|v| v.map(|f| (f * 1e12).round() / 1e12))
943 .collect();
944 Arc::new(rounded) as arrow_array::ArrayRef
945 } else {
946 Arc::clone(col)
947 }
948 })
949 .collect();
950
951 RecordBatch::try_new(schema, columns).unwrap_or_else(|_| batch.clone())
952 })
953 .collect()
954}
955
/// Number of already-accepted fact rows at or above which `FixpointState::compute_delta`
/// switches from in-memory dedup to a hash left-anti join (`arrow_left_anti_dedup`),
/// provided a `TaskContext` is available.
const DEDUP_ANTI_JOIN_THRESHOLD: usize = 300;
966
967fn dedup_batches_all_columns(
982 batches: Vec<RecordBatch>,
983 schema: &SchemaRef,
984) -> DFResult<Vec<RecordBatch>> {
985 let fields: Vec<SortField> = schema
986 .fields()
987 .iter()
988 .map(|f| SortField::new(f.data_type().clone()))
989 .collect();
990 let Ok(converter) = RowConverter::new(fields) else {
994 return Ok(batches);
995 };
996 let mut seen: HashSet<Box<[u8]>> = HashSet::new();
997 let mut out = Vec::with_capacity(batches.len());
998 for batch in batches {
999 if batch.num_rows() == 0 {
1000 continue;
1001 }
1002 let rows = converter
1003 .convert_columns(batch.columns())
1004 .map_err(arrow_err)?;
1005 let mut keep = Vec::with_capacity(batch.num_rows());
1006 for row_idx in 0..batch.num_rows() {
1007 let row_bytes: Box<[u8]> = rows.row(row_idx).data().into();
1008 keep.push(seen.insert(row_bytes));
1009 }
1010 let keep_mask = arrow_array::BooleanArray::from(keep);
1011 let cols = batch
1012 .columns()
1013 .iter()
1014 .map(|c| arrow::compute::filter(c.as_ref(), &keep_mask).map_err(arrow_err))
1015 .collect::<DFResult<Vec<_>>>()?;
1016 if cols.first().is_some_and(|c| !c.is_empty()) {
1017 out.push(RecordBatch::try_new(Arc::clone(schema), cols).map_err(arrow_err)?);
1018 }
1019 }
1020 Ok(out)
1021}
1022
1023async fn arrow_left_anti_dedup(
1024 candidates: Vec<RecordBatch>,
1025 existing: &[RecordBatch],
1026 schema: &SchemaRef,
1027 task_ctx: &Arc<TaskContext>,
1028) -> DFResult<Vec<RecordBatch>> {
1029 if existing.is_empty() || existing.iter().all(|b| b.num_rows() == 0) {
1030 return dedup_batches_all_columns(candidates, schema);
1033 }
1034
1035 let left: Arc<dyn ExecutionPlan> = Arc::new(InMemoryExec::new(candidates, Arc::clone(schema)));
1036 let right: Arc<dyn ExecutionPlan> =
1037 Arc::new(InMemoryExec::new(existing.to_vec(), Arc::clone(schema)));
1038
1039 let on: Vec<(
1040 Arc<dyn datafusion::physical_plan::PhysicalExpr>,
1041 Arc<dyn datafusion::physical_plan::PhysicalExpr>,
1042 )> = schema
1043 .fields()
1044 .iter()
1045 .enumerate()
1046 .map(|(i, field)| {
1047 let l: Arc<dyn datafusion::physical_plan::PhysicalExpr> = Arc::new(
1048 datafusion::physical_plan::expressions::Column::new(field.name(), i),
1049 );
1050 let r: Arc<dyn datafusion::physical_plan::PhysicalExpr> = Arc::new(
1051 datafusion::physical_plan::expressions::Column::new(field.name(), i),
1052 );
1053 (l, r)
1054 })
1055 .collect();
1056
1057 if on.is_empty() {
1058 return Ok(vec![]);
1059 }
1060
1061 let join = HashJoinExec::try_new(
1062 left,
1063 right,
1064 on,
1065 None,
1066 &JoinType::LeftAnti,
1067 None,
1068 PartitionMode::CollectLeft,
1069 datafusion::common::NullEquality::NullEqualsNull,
1070 false,
1075 )?;
1076
1077 let join_arc: Arc<dyn ExecutionPlan> = Arc::new(join);
1078 let anti = collect_all_partitions(&join_arc, task_ctx.clone()).await?;
1081 dedup_batches_all_columns(anti, schema)
1082}
1083
/// Describes one IS-reference from a clause body to a derived rule.
#[derive(Debug, Clone)]
pub struct IsRefBinding {
    /// Index of the `DerivedScanEntry` holding the referenced rule's facts.
    pub derived_scan_index: usize,
    /// Name of the referenced rule; also used to name the `__prob_complement_*` column.
    pub rule_name: String,
    /// Whether the reference points back at the rule being defined.
    /// NOTE(review): inferred from the name — confirm.
    pub is_self_ref: bool,
    /// Negated reference; the fixpoint loop filters clause output against the target's facts.
    pub negated: bool,
    /// Column pairs matched when applying a negation (anti-join or probability complement).
    /// NOTE(review): which side of each pair is the body vs. the target is not visible here.
    pub anti_join_cols: Vec<(String, String)>,
    /// Whether the referenced rule carries a PROB column.
    pub target_has_prob: bool,
    /// Name of the referenced rule's PROB column, when known.
    pub target_prob_col: Option<String>,
    /// Column pairs used for provenance joins.
    /// NOTE(review): not read in this section — confirm semantics with its consumer.
    pub provenance_join_cols: Vec<(String, String)>,
}
1115
/// One clause of a rule evaluated inside the fixpoint loop.
#[derive(Debug)]
pub struct FixpointClausePlan {
    /// Logical plan of the clause body; executed once per rule per iteration.
    pub body_logical: LogicalPlan,
    /// IS-references made by the body; negated ones are applied to the body's output.
    pub is_ref_bindings: Vec<IsRefBinding>,
    /// Clause priority, if any.
    pub priority: Option<i64>,
    /// ALONG binding names of the clause.
    pub along_bindings: Vec<String>,
    /// Model invocations associated with the clause.
    pub model_invocations: Vec<ModelInvocation>,
}
1132
/// A rule participating in fixpoint evaluation, with everything needed to evaluate
/// its clauses and post-process its facts.
#[derive(Debug)]
pub struct FixpointRulePlan {
    /// Rule name.
    pub name: String,
    /// Clauses whose outputs are combined into the rule's candidates each iteration.
    pub clauses: Vec<FixpointClausePlan>,
    /// Declared output schema (reconciled against the physical output when merging).
    pub yield_schema: SchemaRef,
    /// Positions of the KEY columns within `yield_schema`.
    pub key_column_indices: Vec<usize>,
    /// Rule-level priority, if any.
    pub priority: Option<i64>,
    /// Whether the rule has a FOLD.
    pub has_fold: bool,
    /// FOLD bindings; a non-empty list enables monotonic aggregate tracking in the loop.
    pub fold_bindings: Vec<FoldBinding>,
    /// HAVING predicates.
    pub having: Vec<Expr>,
    /// Whether the rule has BEST BY; merges use it only together with non-empty criteria.
    pub has_best_by: bool,
    /// BEST BY sort criteria.
    pub best_by_criteria: Vec<SortCriterion>,
    /// Whether the rule uses PRIORITY.
    pub has_priority: bool,
    /// NOTE(review): meaning not visible in this section — confirm.
    pub deterministic: bool,
    /// Name of the rule's PROB column, if it has one.
    pub prob_column_name: Option<String>,
    /// NOTE(review): presumably marks non-linear recursion — confirm.
    pub non_linear: bool,
}
1172
1173#[expect(clippy::too_many_arguments, reason = "Fixpoint loop needs all context")]
1182async fn run_fixpoint_loop(
1183 rules: Vec<FixpointRulePlan>,
1184 max_iterations: usize,
1185 timeout: Duration,
1186 graph_ctx: Arc<GraphExecutionContext>,
1187 session_ctx: Arc<RwLock<datafusion::prelude::SessionContext>>,
1188 storage: Arc<StorageManager>,
1189 schema_info: Arc<UniSchema>,
1190 params: HashMap<String, Value>,
1191 registry: Arc<DerivedScanRegistry>,
1192 output_schema: SchemaRef,
1193 max_derived_bytes: usize,
1194 derivation_tracker: Option<Arc<ProvenanceStore>>,
1195 iteration_counts: Arc<StdRwLock<HashMap<String, usize>>>,
1196 strict_probability_domain: bool,
1197 probability_epsilon: f64,
1198 exact_probability: bool,
1199 max_bdd_variables: usize,
1200 warnings_slot: Arc<StdRwLock<Vec<RuntimeWarning>>>,
1201 approximate_slot: Arc<StdRwLock<HashMap<String, Vec<String>>>>,
1202 top_k_proofs: usize,
1203 timeout_flag: Arc<std::sync::atomic::AtomicU8>,
1204 semiring_kind: SemiringKind,
1205 classifier_registry: Arc<ClassifierRegistry>,
1206 classifier_cache: Option<Arc<ModelInvocationCache>>,
1207 classifier_provenance_store: Option<Arc<uni_locy::NeuralProvenanceStore>>,
1208) -> DFResult<Vec<RecordBatch>> {
1209 let start = Instant::now();
1210 let task_ctx = session_ctx.read().task_ctx();
1211
1212 if semiring_kind == SemiringKind::MaxMinProb {
1217 let mut warnings = warnings_slot.write().unwrap_or_else(|e| e.into_inner());
1218 let mut already_warned: HashSet<String> = warnings
1219 .iter()
1220 .filter(|w| w.code == RuntimeWarningCode::FuzzyNotProbabilistic)
1221 .map(|w| w.rule_name.clone())
1222 .collect();
1223 for rule in &rules {
1224 if rule.prob_column_name.is_some() && !already_warned.contains(&rule.name) {
1225 warnings.push(RuntimeWarning {
1226 code: RuntimeWarningCode::FuzzyNotProbabilistic,
1227 message: format!(
1228 "rule '{}' carries a PROB column but is being evaluated under \
1229 the MaxMinProb (fuzzy / Viterbi) semiring; outputs are fuzzy \
1230 truth values, not probabilities",
1231 rule.name
1232 ),
1233 rule_name: rule.name.clone(),
1234 variable_count: None,
1235 key_group: None,
1236 });
1237 already_warned.insert(rule.name.clone());
1238 }
1239 }
1240 }
1241
1242 let mut states: Vec<FixpointState> = rules
1244 .iter()
1245 .map(|rule| {
1246 let monotonic_agg = if !rule.fold_bindings.is_empty() {
1247 let bindings: Vec<MonotonicFoldBinding> = rule
1248 .fold_bindings
1249 .iter()
1250 .map(|fb| MonotonicFoldBinding {
1251 fold_name: fb.output_name.clone(),
1252 aggregate: std::sync::Arc::clone(&fb.aggregate),
1253 input_col_index: fb.input_col_index,
1254 input_col_name: fb.input_col_name.clone(),
1255 })
1256 .collect();
1257 Some(MonotonicAggState::new(bindings))
1258 } else {
1259 None
1260 };
1261 FixpointState::new_with_semiring(
1262 rule.name.clone(),
1263 Arc::clone(&rule.yield_schema),
1264 rule.key_column_indices.clone(),
1265 max_derived_bytes,
1266 monotonic_agg,
1267 strict_probability_domain,
1268 semiring_kind,
1269 )
1270 })
1271 .collect();
1272
1273 let mut converged = false;
1275 let mut total_iters = 0usize;
1276 for iteration in 0..max_iterations {
1277 total_iters = iteration + 1;
1278 tracing::debug!("fixpoint iteration {}", iteration);
1279 let mut any_changed = false;
1280
1281 for rule_idx in 0..rules.len() {
1282 let rule = &rules[rule_idx];
1283
1284 update_derived_scan_handles(®istry, &states, rule_idx, &rules);
1286
1287 let mut all_candidates = Vec::new();
1289 let mut clause_candidates: Vec<Vec<RecordBatch>> = Vec::new();
1290 for clause in &rule.clauses {
1291 let mut batches = execute_subplan(
1297 &clause.body_logical,
1298 ¶ms,
1299 &HashMap::new(),
1300 &graph_ctx,
1301 &session_ctx,
1302 &storage,
1303 &schema_info,
1304 None, )
1306 .await?;
1307 for binding in &clause.is_ref_bindings {
1309 if binding.negated
1310 && !binding.anti_join_cols.is_empty()
1311 && let Some(entry) = registry.get(binding.derived_scan_index)
1312 {
1313 let neg_facts = entry.data.read().clone();
1314 if !neg_facts.is_empty() {
1315 if binding.target_has_prob && rule.prob_column_name.is_some() {
1316 let complement_col =
1318 format!("__prob_complement_{}", binding.rule_name);
1319 if let Some(prob_col) = &binding.target_prob_col {
1320 batches = apply_prob_complement_composite(
1321 batches,
1322 &neg_facts,
1323 &binding.anti_join_cols,
1324 prob_col,
1325 &complement_col,
1326 )?;
1327 } else {
1328 batches = apply_anti_join_composite(
1330 batches,
1331 &neg_facts,
1332 &binding.anti_join_cols,
1333 )?;
1334 }
1335 } else {
1336 batches = apply_anti_join_composite(
1338 batches,
1339 &neg_facts,
1340 &binding.anti_join_cols,
1341 )?;
1342 }
1343 }
1344 }
1345 }
1346 let complement_cols: Vec<String> = if !batches.is_empty() {
1348 batches[0]
1349 .schema()
1350 .fields()
1351 .iter()
1352 .filter(|f| f.name().starts_with("__prob_complement_"))
1353 .map(|f| f.name().clone())
1354 .collect()
1355 } else {
1356 vec![]
1357 };
1358 if !complement_cols.is_empty() {
1359 batches = multiply_prob_factors(
1360 batches,
1361 rule.prob_column_name.as_deref(),
1362 &complement_cols,
1363 )?;
1364 }
1365
1366 clause_candidates.push(batches.clone());
1367 all_candidates.extend(batches);
1368 }
1369
1370 let changed = if rule.has_best_by && !rule.best_by_criteria.is_empty() {
1374 states[rule_idx].merge_best_by(all_candidates, &rule.best_by_criteria)?
1375 } else {
1376 states[rule_idx]
1377 .merge_delta(all_candidates, Some(Arc::clone(&task_ctx)))
1378 .await?
1379 };
1380 if changed {
1381 any_changed = true;
1382 if let Some(ref tracker) = derivation_tracker {
1384 record_provenance(
1385 ProvenanceCtx {
1386 tracker,
1387 registry: ®istry,
1388 warnings_slot: &warnings_slot,
1389 },
1390 rule,
1391 &states[rule_idx],
1392 &clause_candidates,
1393 iteration,
1394 top_k_proofs,
1395 ClassifierRefs {
1396 registry: &classifier_registry,
1397 cache: classifier_cache.as_ref(),
1398 provenance_store: classifier_provenance_store.as_ref(),
1399 },
1400 )
1401 .await;
1402 }
1403 }
1404 }
1405
1406 if !any_changed && states.iter().all(|s| s.is_converged()) {
1408 tracing::debug!("fixpoint converged after {} iterations", iteration + 1);
1409 converged = true;
1410 break;
1411 }
1412
1413 if start.elapsed() > timeout {
1415 tracing::warn!(
1416 "fixpoint timeout after {} iterations; returning partial results",
1417 iteration + 1,
1418 );
1419 interruption::set(&timeout_flag, interruption::TIMEOUT);
1420 break;
1421 }
1422 }
1423
1424 if let Ok(mut counts) = iteration_counts.write() {
1426 for rule in &rules {
1427 counts.insert(rule.name.clone(), total_iters);
1428 }
1429 }
1430
1431 if !converged && interruption::reason(&timeout_flag).is_none() {
1436 tracing::warn!(
1437 "fixpoint did not converge after {max_iterations} iterations; returning partial results",
1438 );
1439 interruption::set(&timeout_flag, interruption::ITERATION_LIMIT);
1440 }
1441
1442 let task_ctx = session_ctx.read().task_ctx();
1444 let mut all_output = Vec::new();
1445
1446 for (rule_idx, state) in states.into_iter().enumerate() {
1447 let rule = &rules[rule_idx];
1448 let mut facts = state.into_facts();
1449 if facts.is_empty() {
1450 continue;
1451 }
1452
1453 let shared_info = if semiring_kind == SemiringKind::MaxMinProb {
1471 None
1472 } else if let Some(ref tracker) = derivation_tracker {
1473 detect_shared_lineage(rule, &facts, tracker, &warnings_slot, semiring_kind)
1474 } else {
1475 None
1476 };
1477
1478 if exact_probability
1480 && let Some(ref info) = shared_info
1481 && let Some(ref tracker) = derivation_tracker
1482 {
1483 facts = apply_exact_wmc(
1484 facts,
1485 rule,
1486 info,
1487 tracker,
1488 max_bdd_variables,
1489 &warnings_slot,
1490 &approximate_slot,
1491 )?;
1492 }
1493
1494 let processed = apply_post_fixpoint_chain(
1495 facts,
1496 rule,
1497 &task_ctx,
1498 strict_probability_domain,
1499 probability_epsilon,
1500 semiring_kind,
1501 derivation_tracker.as_ref().map(Arc::clone),
1502 top_k_proofs,
1503 Some(Arc::clone(®istry)),
1504 )
1505 .await?;
1506 all_output.extend(processed);
1507 }
1508
1509 if all_output.is_empty() {
1511 all_output.push(RecordBatch::new_empty(output_schema));
1512 }
1513
1514 Ok(all_output)
1515}
1516
/// Borrowed handles needed to reconstruct neural (classifier) provenance for a
/// derived fact: the classifier registry, plus an optional invocation cache and an
/// optional neural-provenance store that are consulted before a classifier is called.
pub(crate) struct ClassifierRefs<'a> {
    /// Registry used to look up classifiers by model name.
    pub registry: &'a Arc<ClassifierRegistry>,
    /// Optional cache of raw classifier outputs keyed by (model name, input hash).
    pub cache: Option<&'a Arc<uni_locy::ModelInvocationCache>>,
    /// Optional store of previously recorded neural provenance; checked first.
    pub provenance_store: Option<&'a Arc<uni_locy::NeuralProvenanceStore>>,
}

/// Borrowed provenance-recording context passed to `record_provenance`.
pub(crate) struct ProvenanceCtx<'a> {
    /// Store that receives the provenance annotations.
    pub tracker: &'a Arc<ProvenanceStore>,
    /// Registry of derived-scan entries, used to resolve IS-ref supporting facts.
    pub registry: &'a Arc<DerivedScanRegistry>,
    /// Shared sink for runtime warnings emitted while recording.
    pub warnings_slot: &'a Arc<StdRwLock<Vec<RuntimeWarning>>>,
}
1550
1551async fn record_provenance(
1552 prov: ProvenanceCtx<'_>,
1553 rule: &FixpointRulePlan,
1554 state: &FixpointState,
1555 clause_candidates: &[Vec<RecordBatch>],
1556 iteration: usize,
1557 top_k_proofs: usize,
1558 classifiers: ClassifierRefs<'_>,
1559) {
1560 let tracker = prov.tracker;
1561 let registry = prov.registry;
1562 let warnings_slot = prov.warnings_slot;
1563 let classifier_registry = classifiers.registry;
1564 let classifier_cache = classifiers.cache;
1565 let all_indices: Vec<usize> = (0..rule.yield_schema.fields().len()).collect();
1566
1567 let base_probs = if top_k_proofs > 0 {
1569 tracker.base_fact_probs()
1570 } else {
1571 HashMap::new()
1572 };
1573
1574 let mut topk_acc = TopKProofAccumulator::new();
1575
1576 for delta_batch in state.all_delta() {
1577 for row_idx in 0..delta_batch.num_rows() {
1578 let row_hash = format!(
1579 "{:?}",
1580 extract_scalar_key(delta_batch, &all_indices, row_idx)
1581 )
1582 .into_bytes();
1583 let fact_row = batch_row_to_value_map(delta_batch, row_idx);
1584 let clause_index =
1585 find_clause_for_row(delta_batch, row_idx, &all_indices, clause_candidates);
1586
1587 let support = collect_is_ref_inputs(rule, clause_index, delta_batch, row_idx, registry);
1588
1589 let proof_probability = if top_k_proofs > 0 {
1590 compute_proof_probability(&support, &base_probs)
1591 } else {
1592 None
1593 };
1594
1595 let entry = ProvenanceAnnotation {
1596 rule_name: rule.name.clone(),
1597 clause_index,
1598 support,
1599 along_values: {
1600 let along_names: Vec<String> = rule
1601 .clauses
1602 .get(clause_index)
1603 .map(|c| c.along_bindings.clone())
1604 .unwrap_or_default();
1605 along_names
1606 .iter()
1607 .filter_map(|name| fact_row.get(name).map(|v| (name.clone(), v.clone())))
1608 .collect()
1609 },
1610 iteration,
1611 fact_row: fact_row.clone(),
1612 proof_probability,
1613 neural_calls: collect_neural_calls_for_row(
1614 rule,
1615 clause_index,
1616 &fact_row,
1617 classifier_registry,
1618 classifier_cache,
1619 classifiers.provenance_store,
1620 )
1621 .await,
1622 };
1623 if top_k_proofs > 0 {
1624 topk_acc.accumulate(&entry, &row_hash);
1625 tracker.record_top_k(row_hash, entry, top_k_proofs);
1626 } else {
1627 tracker.record(row_hash, entry);
1628 }
1629 }
1630 }
1631
1632 topk_acc.emit_warning_if_any(rule, top_k_proofs, warnings_slot);
1633}
1634
1635struct TopKProofAccumulator {
1642 per_fact: HashMap<Vec<u8>, Vec<uni_locy::Proof>>,
1643 base_rv_interner: HashMap<Vec<u8>, uni_locy::BaseRv>,
1644 next_rv: u32,
1645}
1646
1647impl TopKProofAccumulator {
1648 fn new() -> Self {
1649 Self {
1650 per_fact: HashMap::new(),
1651 base_rv_interner: HashMap::new(),
1652 next_rv: 0,
1653 }
1654 }
1655
1656 fn accumulate(&mut self, entry: &ProvenanceAnnotation, row_hash: &[u8]) {
1657 let mut base_rvs = uni_locy::BaseRvSet::empty();
1658 for term in &entry.support {
1659 let rv = *self
1660 .base_rv_interner
1661 .entry(term.base_fact_id.clone())
1662 .or_insert_with(|| {
1663 let r = uni_locy::BaseRv(self.next_rv);
1664 self.next_rv += 1;
1665 r
1666 });
1667 base_rvs.insert(rv);
1668 }
1669 self.per_fact
1670 .entry(row_hash.to_vec())
1671 .or_default()
1672 .push(uni_locy::Proof {
1673 weight: entry.proof_probability.unwrap_or(0.0),
1674 base_rvs,
1675 neural_calls: Vec::new(),
1676 });
1677 }
1678
1679 fn emit_warning_if_any(
1680 &self,
1681 rule: &FixpointRulePlan,
1682 top_k_proofs: usize,
1683 warnings_slot: &Arc<StdRwLock<Vec<RuntimeWarning>>>,
1684 ) {
1685 if top_k_proofs == 0 || self.per_fact.is_empty() {
1686 return;
1687 }
1688 let crossed_facts = self
1689 .per_fact
1690 .values()
1691 .filter(|proofs| {
1692 let (_kept, notice) =
1693 uni_locy::merge_top_k_runtime(Vec::new(), (*proofs).clone(), top_k_proofs);
1694 notice == uni_locy::PruneNotice::CrossedDependency
1695 })
1696 .count();
1697 if crossed_facts == 0 {
1698 return;
1699 }
1700 let Ok(mut w) = warnings_slot.write() else {
1701 return;
1702 };
1703 let already = w.iter().any(|rw| {
1704 matches!(
1705 rw.code,
1706 uni_locy::types::RuntimeWarningCode::TopKPruningCrossedDependency
1707 ) && rw.rule_name == rule.name
1708 });
1709 if already {
1710 return;
1711 }
1712 w.push(RuntimeWarning {
1713 code: uni_locy::types::RuntimeWarningCode::TopKPruningCrossedDependency,
1714 rule_name: rule.name.clone(),
1715 message: format!(
1716 "rule '{}': top-K proof pruning (k={}) discarded {} fact(s) \
1717 whose dependencies overlap retained proofs. The retained \
1718 top-{} under-counts the true joint probability for those \
1719 facts (Scallop, Huang et al. 2021). Increase k to recover.",
1720 rule.name, top_k_proofs, crossed_facts, top_k_proofs
1721 ),
1722 variable_count: None,
1723 key_group: None,
1724 });
1725 }
1726}
1727
/// Reconstructs the neural (classifier) calls made by clause `clause_index` of `rule`
/// for one derived fact row.
///
/// For each model invocation in the clause:
/// - the invocation's feature expressions are evaluated against `fact_row` to build a
///   `ClassifyInput`, identified by its `stable_hash()`;
/// - if `provenance_store` already holds a record for (model, hash), that record is
///   reused verbatim;
/// - otherwise the classifier is looked up in `classifier_registry`, and the raw score
///   is taken from `classifier_cache` or computed via `classify` (then cached).
///
/// Invocations whose model is not registered, or whose `classify` call errors, are
/// skipped silently. Returns an empty vector for an out-of-range clause or a clause
/// without model invocations.
async fn collect_neural_calls_for_row(
    rule: &FixpointRulePlan,
    clause_index: usize,
    fact_row: &uni_locy::FactRow,
    classifier_registry: &Arc<ClassifierRegistry>,
    classifier_cache: Option<&Arc<uni_locy::ModelInvocationCache>>,
    provenance_store: Option<&Arc<uni_locy::NeuralProvenanceStore>>,
) -> Vec<uni_locy::NeuralProvenance> {
    let Some(clause) = rule.clauses.get(clause_index) else {
        return Vec::new();
    };
    if clause.model_invocations.is_empty() {
        return Vec::new();
    }
    let mut out = Vec::with_capacity(clause.model_invocations.len());
    for invocation in &clause.model_invocations {
        // Feature name -> value, pairing feature_names with feature_exprs positionally.
        let mut features = std::collections::HashMap::new();
        for (binding_name, feat_expr) in invocation
            .feature_names
            .iter()
            .zip(invocation.feature_exprs.iter())
        {
            features.insert(
                binding_name.clone(),
                eval_feature_expr_against_fact_row(feat_expr, fact_row),
            );
        }
        let input = uni_locy::ClassifyInput { features };
        let input_hash = input.stable_hash();

        // A previously recorded provenance entry for this exact input wins outright.
        if let Some(store) = provenance_store
            && let Some(record) = store.get(&invocation.model_name, input_hash)
        {
            out.push(uni_locy::NeuralProvenance {
                model_name: invocation.model_name.clone(),
                raw_probability: record.raw_probability,
                calibrated_probability: record.calibrated_probability,
                confidence_band: record.confidence_band,
            });
            continue;
        }

        // Unregistered model: nothing can be reconstructed for this invocation.
        let Some(classifier) = classifier_registry.get(&invocation.model_name) else {
            continue;
        };
        let raw = if let Some(v) =
            classifier_cache.and_then(|c| c.get(&invocation.model_name, input_hash))
        {
            v
        } else {
            match classifier.classify(std::slice::from_ref(&input)).await {
                Ok(probs) => {
                    // A single input is classified; an empty result is treated as 0.0.
                    let v = probs.first().copied().unwrap_or(0.0);
                    if let Some(c) = classifier_cache {
                        c.insert(&invocation.model_name, input_hash, v);
                    }
                    v
                }
                // Classifier errors are best-effort: the invocation is simply omitted.
                Err(_) => continue,
            }
        };
        let calibrator = classifier.get_calibrator();
        // NOTE(review): when a calibrator exists, the calibrated probability is set to
        // `raw` itself and the calibrator is only used for the confidence band.
        // Presumably `classify` already returns calibrated output — confirm.
        let calibrated_probability = calibrator.as_ref().map(|_| raw);
        let confidence_band = calibrator.as_ref().and_then(|c| c.confidence_band(raw));
        out.push(uni_locy::NeuralProvenance {
            model_name: invocation.model_name.clone(),
            raw_probability: raw,
            calibrated_probability,
            confidence_band,
        });
    }
    out
}
1842
1843fn eval_feature_expr_against_fact_row(
1851 expr: &uni_cypher::ast::Expr,
1852 fact_row: &uni_locy::FactRow,
1853) -> uni_locy::FeatureValue {
1854 use uni_cypher::ast::Expr;
1855 use uni_locy::FeatureValue;
1856 let value_to_feature = |v: Option<&uni_common::Value>| -> FeatureValue {
1857 match v {
1858 Some(uni_common::Value::Float(f)) => FeatureValue::Float(*f),
1859 Some(uni_common::Value::Int(i)) => FeatureValue::Int(*i),
1860 Some(uni_common::Value::Bool(b)) => FeatureValue::Bool(*b),
1861 Some(uni_common::Value::String(s)) => FeatureValue::String(s.clone()),
1862 Some(uni_common::Value::Node(n)) => {
1863 FeatureValue::Int(n.vid.as_u64() as i64)
1865 }
1866 _ => FeatureValue::Null,
1867 }
1868 };
1869 let resolve_value = |sub: &Expr| -> uni_common::Value {
1873 match sub {
1874 Expr::Variable(name) => fact_row
1875 .get(name)
1876 .cloned()
1877 .unwrap_or(uni_common::Value::Null),
1878 Expr::Property(boxed, prop) if matches!(boxed.as_ref(), Expr::Variable(_)) => {
1879 let Expr::Variable(v) = boxed.as_ref() else {
1880 unreachable!()
1881 };
1882 let key = format!("{}.{}", v, prop);
1883 if let Some(val) = fact_row.get(&key) {
1884 return val.clone();
1885 }
1886 if let Some(uni_common::Value::Node(n)) = fact_row.get(v) {
1887 return n
1888 .properties
1889 .get(prop)
1890 .cloned()
1891 .unwrap_or(uni_common::Value::Null);
1892 }
1893 uni_common::Value::Null
1894 }
1895 Expr::Literal(lit) => lit.to_value(),
1896 Expr::List(items) => {
1897 let mut out = Vec::with_capacity(items.len());
1898 for it in items {
1899 out.push(match it {
1900 Expr::Literal(lit) => lit.to_value(),
1901 _ => uni_common::Value::Null,
1902 });
1903 }
1904 uni_common::Value::List(out)
1905 }
1906 _ => uni_common::Value::Null,
1907 }
1908 };
1909
1910 match expr {
1911 Expr::Variable(name) => value_to_feature(fact_row.get(name)),
1912 Expr::Property(boxed, prop) => {
1913 if let Expr::Variable(v) = boxed.as_ref() {
1914 let key = format!("{}.{}", v, prop);
1916 if let Some(val) = fact_row.get(&key) {
1917 return value_to_feature(Some(val));
1918 }
1919 let hidden_key = format!("__feat_{}_{}", v, prop);
1929 if let Some(val) = fact_row.get(&hidden_key) {
1930 return value_to_feature(Some(val));
1931 }
1932 if let Some(uni_common::Value::Node(n)) = fact_row.get(v) {
1936 return value_to_feature(n.properties.get(prop));
1937 }
1938 }
1939 FeatureValue::Null
1940 }
1941 Expr::FunctionCall { name, args, .. } if name == "similar_to" && args.len() == 2 => {
1942 let lv = resolve_value(&args[0]);
1943 let rv = resolve_value(&args[1]);
1944 match crate::query::similar_to::eval_similar_to_pure(&lv, &rv) {
1945 Ok(uni_common::Value::Float(f)) => FeatureValue::Float(f),
1946 _ => FeatureValue::Null,
1947 }
1948 }
1949 Expr::FunctionCall { name, .. }
1964 if matches!(
1965 name.as_str(),
1966 "degree_centrality"
1967 | "pagerank_score"
1968 | "closeness_centrality"
1969 | "betweenness_centrality"
1970 | "eigenvector_centrality"
1971 | "harmonic_centrality"
1972 | "katz_centrality"
1973 | "avg_neighbor"
1974 | "max_neighbor"
1975 | "sum_neighbor"
1976 ) =>
1977 {
1978 FeatureValue::Null
1979 }
1980 _ => FeatureValue::Null,
1981 }
1982}
1983
1984fn collect_is_ref_inputs(
1985 rule: &FixpointRulePlan,
1986 clause_index: usize,
1987 delta_batch: &RecordBatch,
1988 row_idx: usize,
1989 registry: &Arc<DerivedScanRegistry>,
1990) -> Vec<ProofTerm> {
1991 let clause = match rule.clauses.get(clause_index) {
1992 Some(c) => c,
1993 None => return vec![],
1994 };
1995
1996 let mut inputs = Vec::new();
1997 let delta_schema = delta_batch.schema();
1998
1999 for binding in &clause.is_ref_bindings {
2000 if binding.negated {
2001 continue;
2002 }
2003 if binding.provenance_join_cols.is_empty() {
2004 continue;
2005 }
2006
2007 let body_values: Vec<(String, ScalarKey)> = binding
2009 .provenance_join_cols
2010 .iter()
2011 .filter_map(|(body_col, _derived_col)| {
2012 let col_idx = delta_schema
2013 .fields()
2014 .iter()
2015 .position(|f| f.name() == body_col)?;
2016 let key = extract_scalar_key(delta_batch, &[col_idx], row_idx);
2017 Some((body_col.clone(), key.into_iter().next()?))
2018 })
2019 .collect();
2020
2021 if body_values.len() != binding.provenance_join_cols.len() {
2022 continue;
2023 }
2024
2025 let entry = match registry.get(binding.derived_scan_index) {
2027 Some(e) => e,
2028 None => continue,
2029 };
2030 let source_batches = entry.data.read();
2031 let source_schema = &entry.schema;
2032
2033 for src_batch in source_batches.iter() {
2035 let all_src_indices: Vec<usize> = (0..src_batch.num_columns()).collect();
2036 for src_row in 0..src_batch.num_rows() {
2037 let matches = binding.provenance_join_cols.iter().enumerate().all(
2038 |(i, (_body_col, derived_col))| {
2039 let src_col_idx = source_schema
2040 .fields()
2041 .iter()
2042 .position(|f| f.name() == derived_col);
2043 match src_col_idx {
2044 Some(idx) => {
2045 let src_key = extract_scalar_key(src_batch, &[idx], src_row);
2046 src_key.first() == Some(&body_values[i].1)
2047 }
2048 None => false,
2049 }
2050 },
2051 );
2052 if matches {
2053 let fact_hash = format!(
2054 "{:?}",
2055 extract_scalar_key(src_batch, &all_src_indices, src_row)
2056 )
2057 .into_bytes();
2058 inputs.push(ProofTerm {
2059 source_rule: binding.rule_name.clone(),
2060 base_fact_id: fact_hash,
2061 });
2062 }
2063 }
2064 }
2065 }
2066
2067 inputs
2068}
2069
2070fn collect_is_ref_inputs_for_body_row(
2092 rule: &FixpointRulePlan,
2093 delta_batch: &RecordBatch,
2094 row_idx: usize,
2095 registry: &Arc<DerivedScanRegistry>,
2096) -> Vec<ProofTerm> {
2097 let mut combined: Vec<ProofTerm> = Vec::new();
2098 for clause_index in 0..rule.clauses.len() {
2099 let part = collect_is_ref_inputs(rule, clause_index, delta_batch, row_idx, registry);
2100 combined.extend(part);
2101 }
2102 combined
2103}
2104
/// One fact row of a KEY group flagged by `detect_shared_lineage`, together with the
/// set of base facts (from `compute_lineage`) its derivation depends on.
#[expect(
    dead_code,
    reason = "Fields accessed via SharedLineageInfo in detect_shared_lineage"
)]
pub(crate) struct SharedGroupRow {
    /// Fact hash of the row (as produced by `fact_hash_key`).
    pub fact_hash: Vec<u8>,
    /// Base-fact hashes this row ultimately derives from.
    pub lineage: HashSet<Vec<u8>>,
}

/// Result of `detect_shared_lineage`: the KEY groups in which shared lineage was found.
pub(crate) struct SharedLineageInfo {
    /// Flagged groups keyed by their KEY-column scalar key, each with its rows.
    pub shared_groups: HashMap<Vec<ScalarKey>, Vec<SharedGroupRow>>,
}
2137
2138pub(crate) fn fact_hash_key(batch: &RecordBatch, all_indices: &[usize], row_idx: usize) -> Vec<u8> {
2140 format!("{:?}", extract_scalar_key(batch, all_indices, row_idx)).into_bytes()
2141}
2142
/// Detects KEY groups whose probabilistic (MNOR/MPROD) fold would combine rows that
/// are not independent, and emits the related runtime warnings.
///
/// Returns `None` immediately when the rule has no probability fold. Otherwise:
/// - the facts are grouped by KEY columns, and each fact is identified by
///   `fact_hash_key` over all yield columns;
/// - a group with at least two rows is flagged ("shared") when either:
///   - some recorded row has non-empty support and two rows' lineages (from
///     `compute_lineage`) overlap; or
///   - no row has recorded support, but some row's clause has a non-negated IS-ref
///     binding;
/// - a `CrossGroupCorrelationNotExact` warning is pushed (once per rule) when a direct
///   support fact feeds more than one KEY group;
/// - if any group is flagged, a `SharedProbabilisticDependency` warning is pushed
///   (once per rule, suppressed under the `TopKProofs` semiring), and
///   `Some(SharedLineageInfo)` is returned with the flagged groups. Otherwise the
///   result is `None`.
fn detect_shared_lineage(
    rule: &FixpointRulePlan,
    pre_fold_facts: &[RecordBatch],
    tracker: &Arc<ProvenanceStore>,
    warnings_slot: &Arc<StdRwLock<Vec<RuntimeWarning>>>,
    semiring_kind: SemiringKind,
) -> Option<SharedLineageInfo> {
    use uni_locy::{RuntimeWarning, RuntimeWarningCode};

    // Only probability aggregates rely on the independence assumption checked here.
    let has_prob_fold = rule
        .fold_bindings
        .iter()
        .any(|fb| fb.aggregate.is_probability_aggregate());
    if !has_prob_fold {
        return None;
    }

    let key_indices = &rule.key_column_indices;
    let all_indices: Vec<usize> = (0..rule.yield_schema.fields().len()).collect();

    // KEY -> fact hashes of the rows in that group.
    let mut groups: HashMap<Vec<ScalarKey>, Vec<Vec<u8>>> = HashMap::new();
    for batch in pre_fold_facts {
        for row_idx in 0..batch.num_rows() {
            let key = extract_scalar_key(batch, key_indices, row_idx);
            let fact_hash = fact_hash_key(batch, &all_indices, row_idx);
            groups.entry(key).or_default().push(fact_hash);
        }
    }

    let mut shared_groups: HashMap<Vec<ScalarKey>, Vec<SharedGroupRow>> = HashMap::new();
    let mut any_shared = false;

    for (key, fact_hashes) in &groups {
        // A single row cannot share lineage with anything in its own group.
        if fact_hashes.len() < 2 {
            continue;
        }

        // True when at least one row has a provenance entry with non-empty support.
        let mut has_inputs = false;
        // Lineage of each row, parallel to `fact_hashes`.
        let mut per_row_bases: Vec<HashSet<Vec<u8>>> = Vec::new();
        for fh in fact_hashes {
            let bases = compute_lineage(fh, tracker, &mut HashSet::new());
            if let Some(entry) = tracker.lookup(fh)
                && !entry.support.is_empty()
            {
                has_inputs = true;
            }
            per_row_bases.push(bases);
        }

        let shared_found = if has_inputs {
            // Pairwise check: any two rows with overlapping lineage flag the group.
            let mut found = false;
            'outer: for i in 0..per_row_bases.len() {
                for j in (i + 1)..per_row_bases.len() {
                    if !per_row_bases[i].is_disjoint(&per_row_bases[j]) {
                        found = true;
                        break 'outer;
                    }
                }
            }
            found
        } else {
            // No recorded support: conservatively flag the group if any row came from a
            // clause with a positive IS-ref binding (its support presumably could not
            // be resolved — NOTE(review): confirm intent).
            fact_hashes.iter().any(|fh| {
                tracker.lookup(fh).is_some_and(|entry| {
                    rule.clauses
                        .get(entry.clause_index)
                        .is_some_and(|clause| clause.is_ref_bindings.iter().any(|b| !b.negated))
                })
            })
        };

        if shared_found {
            any_shared = true;
            let rows: Vec<SharedGroupRow> = fact_hashes
                .iter()
                .zip(per_row_bases)
                .map(|(fh, bases)| SharedGroupRow {
                    fact_hash: fh.clone(),
                    lineage: bases,
                })
                .collect();
            shared_groups.insert(key.clone(), rows);
        }
    }

    // Cross-group check over *direct* support facts: per-group BDD correction cannot
    // account for a support fact that feeds several KEY groups.
    {
        // Direct support fact id -> KEY groups it feeds.
        let mut input_to_groups: HashMap<Vec<u8>, HashSet<Vec<ScalarKey>>> = HashMap::new();
        for (key, fact_hashes) in &groups {
            for fh in fact_hashes {
                if let Some(entry) = tracker.lookup(fh) {
                    for input in &entry.support {
                        input_to_groups
                            .entry(input.base_fact_id.clone())
                            .or_default()
                            .insert(key.clone());
                    }
                }
            }
        }
        let has_cross_group = input_to_groups.values().any(|g| g.len() > 1);
        if has_cross_group && let Ok(mut warnings) = warnings_slot.write() {
            let already_warned = warnings.iter().any(|w| {
                w.code == RuntimeWarningCode::CrossGroupCorrelationNotExact
                    && w.rule_name == rule.name
            });
            if !already_warned {
                // One example for the warning: first 8 bytes of the input id (hex) plus
                // the sorted group keys. NOTE(review): which input is chosen depends on
                // HashMap iteration order, so the example may vary between runs.
                let example = input_to_groups
                    .iter()
                    .find(|(_, g)| g.len() > 1)
                    .map(|(input, groups)| {
                        let short = input
                            .iter()
                            .take(8)
                            .map(|b| format!("{:02x}", b))
                            .collect::<String>();
                        let mut group_strs: Vec<String> =
                            groups.iter().map(|k| format!("{:?}", k)).collect();
                        group_strs.sort();
                        format!(
                            "input {} shared by groups [{}]",
                            short,
                            group_strs.join(", ")
                        )
                    });
                let shared_variable_count =
                    input_to_groups.values().filter(|g| g.len() > 1).count();
                warnings.push(RuntimeWarning {
                    code: RuntimeWarningCode::CrossGroupCorrelationNotExact,
                    message: format!(
                        "Rule '{}': {} IS-ref base fact(s) are shared across different \
                         KEY groups. BDD corrects per-group probabilities but cannot \
                         account for cross-group correlations.",
                        rule.name, shared_variable_count
                    ),
                    rule_name: rule.name.clone(),
                    variable_count: Some(shared_variable_count),
                    key_group: example,
                });
            }
        }
    }

    if any_shared {
        // Under the TopKProofs semiring this warning is suppressed.
        let suppress_under_topk = matches!(semiring_kind, SemiringKind::TopKProofs { .. });
        if !suppress_under_topk && let Ok(mut warnings) = warnings_slot.write() {
            let already_warned = warnings.iter().any(|w| {
                w.code == RuntimeWarningCode::SharedProbabilisticDependency
                    && w.rule_name == rule.name
            });
            if !already_warned {
                warnings.push(RuntimeWarning {
                    code: RuntimeWarningCode::SharedProbabilisticDependency,
                    message: format!(
                        "Rule '{}' aggregates with MNOR/MPROD but some proof paths \
                         share intermediate facts, violating the independence assumption. \
                         Results may overestimate probability.",
                        rule.name
                    ),
                    rule_name: rule.name.clone(),
                    variable_count: None,
                    key_group: None,
                });
            }
        }
        Some(SharedLineageInfo { shared_groups })
    } else {
        None
    }
}
2347
2348#[allow(
2356 clippy::too_many_arguments,
2357 reason = "context bundle would be over-engineering for one call site"
2358)]
2359pub(crate) async fn record_and_detect_lineage_nonrecursive(
2360 rule: &FixpointRulePlan,
2361 tagged_clause_facts: &[(usize, Vec<RecordBatch>)],
2362 tracker: &Arc<ProvenanceStore>,
2363 warnings_slot: &Arc<StdRwLock<Vec<RuntimeWarning>>>,
2364 registry: &Arc<DerivedScanRegistry>,
2365 top_k_proofs: usize,
2366 classifiers: ClassifierRefs<'_>,
2367 semiring_kind: SemiringKind,
2368) -> Option<SharedLineageInfo> {
2369 let classifier_registry = classifiers.registry;
2370 let classifier_cache = classifiers.cache;
2371 let all_indices: Vec<usize> = (0..rule.yield_schema.fields().len()).collect();
2372
2373 let base_probs = if top_k_proofs > 0 {
2375 tracker.base_fact_probs()
2376 } else {
2377 HashMap::new()
2378 };
2379
2380 let mut topk_acc = TopKProofAccumulator::new();
2381
2382 for (clause_index, batches) in tagged_clause_facts {
2384 for batch in batches {
2385 for row_idx in 0..batch.num_rows() {
2386 let row_hash = fact_hash_key(batch, &all_indices, row_idx);
2387 let fact_row = batch_row_to_value_map(batch, row_idx);
2388
2389 let support = collect_is_ref_inputs(rule, *clause_index, batch, row_idx, registry);
2390
2391 let proof_probability = if top_k_proofs > 0 {
2392 compute_proof_probability(&support, &base_probs)
2393 } else {
2394 None
2395 };
2396
2397 let entry = ProvenanceAnnotation {
2398 rule_name: rule.name.clone(),
2399 clause_index: *clause_index,
2400 support,
2401 along_values: {
2402 let along_names: Vec<String> = rule
2403 .clauses
2404 .get(*clause_index)
2405 .map(|c| c.along_bindings.clone())
2406 .unwrap_or_default();
2407 along_names
2408 .iter()
2409 .filter_map(|name| {
2410 fact_row.get(name).map(|v| (name.clone(), v.clone()))
2411 })
2412 .collect()
2413 },
2414 iteration: 0,
2415 fact_row: fact_row.clone(),
2416 proof_probability,
2417 neural_calls: collect_neural_calls_for_row(
2418 rule,
2419 *clause_index,
2420 &fact_row,
2421 classifier_registry,
2422 classifier_cache,
2423 classifiers.provenance_store,
2424 )
2425 .await,
2426 };
2427 if top_k_proofs > 0 {
2428 topk_acc.accumulate(&entry, &row_hash);
2429 tracker.record_top_k(row_hash, entry, top_k_proofs);
2430 } else {
2431 tracker.record(row_hash, entry);
2432 }
2433 }
2434 }
2435 }
2436
2437 topk_acc.emit_warning_if_any(rule, top_k_proofs, warnings_slot);
2438
2439 let all_facts: Vec<RecordBatch> = tagged_clause_facts
2441 .iter()
2442 .flat_map(|(_, batches)| batches.iter().cloned())
2443 .collect();
2444 detect_shared_lineage(rule, &all_facts, tracker, warnings_slot, semiring_kind)
2445}
2446
2447pub(crate) fn apply_exact_wmc(
2455 pre_fold_facts: Vec<RecordBatch>,
2456 rule: &FixpointRulePlan,
2457 shared_info: &SharedLineageInfo,
2458 tracker: &Arc<ProvenanceStore>,
2459 max_bdd_variables: usize,
2460 warnings_slot: &Arc<StdRwLock<Vec<RuntimeWarning>>>,
2461 approximate_slot: &Arc<StdRwLock<HashMap<String, Vec<String>>>>,
2462) -> DFResult<Vec<RecordBatch>> {
2463 use crate::query::df_graph::locy_bdd::{SemiringOp, weighted_model_count};
2464 use uni_locy::{RuntimeWarning, RuntimeWarningCode};
2465
2466 let prob_fold = rule
2470 .fold_bindings
2471 .iter()
2472 .find(|fb| fb.aggregate.is_probability_aggregate());
2473 let prob_fold = match prob_fold {
2474 Some(f) => f,
2475 None => return Ok(pre_fold_facts),
2476 };
2477 let semiring_op = if prob_fold.aggregate.is_noisy_or() {
2478 SemiringOp::Disjunction
2479 } else {
2480 SemiringOp::Conjunction
2481 };
2482 let prob_col_idx = prob_fold.input_col_index;
2483 let prob_col_name = rule.yield_schema.field(prob_col_idx).name().clone();
2484
2485 let key_indices = &rule.key_column_indices;
2486 let all_indices: Vec<usize> = (0..rule.yield_schema.fields().len()).collect();
2487
2488 let shared_keys: HashSet<Vec<ScalarKey>> = shared_info.shared_groups.keys().cloned().collect();
2490
2491 struct GroupAccum {
2494 base_facts: Vec<HashSet<Vec<u8>>>,
2495 base_probs: HashMap<Vec<u8>, f64>,
2496 representative: (usize, usize),
2498 row_locations: Vec<(usize, usize)>,
2499 }
2500
2501 let mut group_accums: HashMap<Vec<ScalarKey>, GroupAccum> = HashMap::new();
2502 let mut non_shared_rows: Vec<(usize, usize)> = Vec::new(); for (batch_idx, batch) in pre_fold_facts.iter().enumerate() {
2505 for row_idx in 0..batch.num_rows() {
2506 let key = extract_scalar_key(batch, key_indices, row_idx);
2507 if shared_keys.contains(&key) {
2508 let fact_hash = fact_hash_key(batch, &all_indices, row_idx);
2509 let bases = compute_lineage(&fact_hash, tracker, &mut HashSet::new());
2510
2511 let accum = group_accums.entry(key).or_insert_with(|| GroupAccum {
2512 base_facts: Vec::new(),
2513 base_probs: HashMap::new(),
2514 representative: (batch_idx, row_idx),
2515 row_locations: Vec::new(),
2516 });
2517
2518 for bf in &bases {
2520 if !accum.base_probs.contains_key(bf)
2521 && let Some(entry) = tracker.lookup(bf)
2522 && let Some(val) = entry.fact_row.get(&prob_col_name)
2523 && let Some(p) = value_to_f64(val)
2524 {
2525 accum.base_probs.insert(bf.clone(), p);
2526 }
2527 }
2528
2529 accum.base_facts.push(bases);
2530 accum.row_locations.push((batch_idx, row_idx));
2531 } else {
2532 non_shared_rows.push((batch_idx, row_idx));
2533 }
2534 }
2535 }
2536
2537 let mut keep_rows: HashSet<(usize, usize)> = HashSet::new();
2540 let mut overrides: HashMap<(usize, usize), f64> = HashMap::new();
2542
2543 for &loc in &non_shared_rows {
2545 keep_rows.insert(loc);
2546 }
2547
2548 for (key, accum) in &group_accums {
2549 let bdd_result = weighted_model_count(
2550 &accum.base_facts,
2551 &accum.base_probs,
2552 semiring_op,
2553 max_bdd_variables,
2554 );
2555
2556 if bdd_result.approximated {
2557 if let Ok(mut warnings) = warnings_slot.write() {
2559 let key_desc = format!("{key:?}");
2560 let already_warned = warnings.iter().any(|w| {
2561 w.code == RuntimeWarningCode::BddLimitExceeded
2562 && w.rule_name == rule.name
2563 && w.key_group.as_deref() == Some(&key_desc)
2564 });
2565 if !already_warned {
2566 warnings.push(RuntimeWarning {
2567 code: RuntimeWarningCode::BddLimitExceeded,
2568 message: format!(
2569 "Rule '{}': BDD variable limit exceeded ({} > {}). \
2570 Falling back to independence-mode result.",
2571 rule.name, bdd_result.variable_count, max_bdd_variables
2572 ),
2573 rule_name: rule.name.clone(),
2574 variable_count: Some(bdd_result.variable_count),
2575 key_group: Some(key_desc),
2576 });
2577 }
2578 }
2579 if let Ok(mut approx) = approximate_slot.write() {
2580 let key_desc = format!("{key:?}");
2581 approx.entry(rule.name.clone()).or_default().push(key_desc);
2582 }
2583 for &loc in &accum.row_locations {
2585 keep_rows.insert(loc);
2586 }
2587 } else {
2588 keep_rows.insert(accum.representative);
2590 overrides.insert(accum.representative, bdd_result.probability);
2591 }
2592 }
2593
2594 let mut result_batches = Vec::new();
2596 for (batch_idx, batch) in pre_fold_facts.iter().enumerate() {
2597 let kept_indices: Vec<usize> = (0..batch.num_rows())
2598 .filter(|&row_idx| keep_rows.contains(&(batch_idx, row_idx)))
2599 .collect();
2600
2601 if kept_indices.is_empty() {
2602 continue;
2603 }
2604
2605 let indices = arrow::array::UInt32Array::from(
2606 kept_indices.iter().map(|&i| i as u32).collect::<Vec<_>>(),
2607 );
2608 let mut columns: Vec<arrow::array::ArrayRef> = batch
2609 .columns()
2610 .iter()
2611 .map(|col| arrow::compute::take(col, &indices, None))
2612 .collect::<Result<Vec<_>, _>>()
2613 .map_err(arrow_err)?;
2614
2615 let override_map: Vec<Option<f64>> = kept_indices
2617 .iter()
2618 .map(|&row_idx| overrides.get(&(batch_idx, row_idx)).copied())
2619 .collect();
2620
2621 if override_map.iter().any(|o| o.is_some()) && prob_col_idx < columns.len() {
2622 let existing_prob = columns[prob_col_idx]
2624 .as_any()
2625 .downcast_ref::<arrow::array::Float64Array>();
2626 let new_values: Vec<f64> = override_map
2627 .iter()
2628 .enumerate()
2629 .map(|(i, ov)| match ov {
2630 Some(p) => *p,
2631 None => existing_prob.map(|arr| arr.value(i)).unwrap_or(0.0),
2632 })
2633 .collect();
2634 columns[prob_col_idx] = Arc::new(arrow::array::Float64Array::from(new_values));
2635 }
2636
2637 let result_batch = RecordBatch::try_new(batch.schema(), columns).map_err(arrow_err)?;
2638 result_batches.push(result_batch);
2639 }
2640
2641 Ok(result_batches)
2642}
2643
2644fn value_to_f64(val: &uni_common::Value) -> Option<f64> {
2646 match val {
2647 uni_common::Value::Float(f) => Some(*f),
2648 uni_common::Value::Int(i) => Some(*i as f64),
2649 _ => None,
2650 }
2651}
2652
2653fn compute_lineage(
2660 fact_hash: &[u8],
2661 tracker: &Arc<ProvenanceStore>,
2662 visited: &mut HashSet<Vec<u8>>,
2663) -> HashSet<Vec<u8>> {
2664 if !visited.insert(fact_hash.to_vec()) {
2665 return HashSet::new(); }
2667
2668 match tracker.lookup(fact_hash) {
2669 Some(entry) if !entry.support.is_empty() => {
2670 let mut bases = HashSet::new();
2671 for input in &entry.support {
2672 let child_bases = compute_lineage(&input.base_fact_id, tracker, visited);
2673 bases.extend(child_bases);
2674 }
2675 bases
2676 }
2677 _ => {
2678 let mut set = HashSet::new();
2680 set.insert(fact_hash.to_vec());
2681 set
2682 }
2683 }
2684}
2685
2686fn find_clause_for_row(
2691 delta_batch: &RecordBatch,
2692 row_idx: usize,
2693 all_indices: &[usize],
2694 clause_candidates: &[Vec<RecordBatch>],
2695) -> usize {
2696 let target_key = extract_scalar_key(delta_batch, all_indices, row_idx);
2697 for (clause_idx, batches) in clause_candidates.iter().enumerate() {
2698 for batch in batches {
2699 if batch.num_columns() != all_indices.len() {
2700 continue;
2701 }
2702 for r in 0..batch.num_rows() {
2703 if extract_scalar_key(batch, all_indices, r) == target_key {
2704 return clause_idx;
2705 }
2706 }
2707 }
2708 }
2709 0
2710}
2711
2712fn batch_row_to_value_map(
2714 batch: &RecordBatch,
2715 row_idx: usize,
2716) -> std::collections::HashMap<String, Value> {
2717 use uni_store::storage::arrow_convert::arrow_to_value;
2718
2719 let schema = batch.schema();
2720 schema
2721 .fields()
2722 .iter()
2723 .enumerate()
2724 .map(|(col_idx, field)| {
2725 let col = batch.column(col_idx);
2726 let val = arrow_to_value(col.as_ref(), row_idx, None);
2727 (field.name().clone(), val)
2728 })
2729 .collect()
2730}
2731
2732pub fn apply_anti_join(
2737 batches: Vec<RecordBatch>,
2738 neg_facts: &[RecordBatch],
2739 left_col: &str,
2740 right_col: &str,
2741) -> datafusion::error::Result<Vec<RecordBatch>> {
2742 use arrow::compute::filter_record_batch;
2743 use arrow_array::{Array as _, BooleanArray, UInt64Array};
2744
2745 let mut banned: std::collections::HashSet<u64> = std::collections::HashSet::new();
2747 for batch in neg_facts {
2748 let Ok(idx) = batch.schema().index_of(right_col) else {
2749 continue;
2750 };
2751 let arr = batch.column(idx);
2752 let Some(vids) = arr.as_any().downcast_ref::<UInt64Array>() else {
2753 continue;
2754 };
2755 for i in 0..vids.len() {
2756 if !vids.is_null(i) {
2757 banned.insert(vids.value(i));
2758 }
2759 }
2760 }
2761
2762 if banned.is_empty() {
2763 return Ok(batches);
2764 }
2765
2766 let mut result = Vec::new();
2768 for batch in batches {
2769 let Ok(idx) = batch.schema().index_of(left_col) else {
2770 result.push(batch);
2771 continue;
2772 };
2773 let arr = batch.column(idx);
2774 let Some(vids) = arr.as_any().downcast_ref::<UInt64Array>() else {
2775 result.push(batch);
2776 continue;
2777 };
2778 let keep: Vec<bool> = (0..vids.len())
2779 .map(|i| vids.is_null(i) || !banned.contains(&vids.value(i)))
2780 .collect();
2781 let keep_arr = BooleanArray::from(keep);
2782 let filtered = filter_record_batch(&batch, &keep_arr).map_err(arrow_err)?;
2783 if filtered.num_rows() > 0 {
2784 result.push(filtered);
2785 }
2786 }
2787 Ok(result)
2788}
2789
2790#[allow(clippy::too_many_arguments)]
2811pub(crate) async fn apply_model_invocations(
2812 batches: Vec<RecordBatch>,
2813 invocations: &[uni_locy::ModelInvocation],
2814 registry: &Arc<ClassifierRegistry>,
2815 cache: Option<&Arc<uni_locy::ModelInvocationCache>>,
2816 provenance_store: Option<&Arc<uni_locy::NeuralProvenanceStore>>,
2817 path_context_handles: &HashMap<
2818 String,
2819 crate::query::df_graph::locy_model_invoke::PathContextHandle,
2820 >,
2821 xervo_runtime: &crate::query::df_graph::locy_model_invoke::XervoRuntimeHandle,
2822 graph_algo: &crate::query::df_graph::locy_model_invoke::GraphAlgoHandle,
2823) -> DFResult<Vec<RecordBatch>> {
2824 use uni_locy::ClassifyInput;
2825 if batches.is_empty() || invocations.is_empty() {
2826 return Ok(batches);
2827 }
2828 let semantic_match_embeddings =
2832 pre_embed_semantic_match_queries(invocations, xervo_runtime).await?;
2833 let graph_feature_maps = precompute_graph_feature_maps(invocations, graph_algo).await?;
2838 let neighbor_feature_maps =
2839 precompute_neighbor_feature_maps(invocations, &batches, graph_algo).await?;
2840 let mut out_batches = Vec::with_capacity(batches.len());
2841 for batch in batches {
2842 let mut current = batch;
2843 for invocation in invocations {
2844 let classifier = registry.get(&invocation.model_name).ok_or_else(|| {
2845 datafusion::error::DataFusionError::Execution(format!(
2846 "neural classifier '{}' not registered; \
2847 add it to LocyConfig::classifier_registry",
2848 invocation.model_name
2849 ))
2850 })?;
2851
2852 let resolvers = build_feature_resolvers(
2864 ¤t,
2865 invocation,
2866 path_context_handles,
2867 &semantic_match_embeddings,
2868 &graph_feature_maps,
2869 &neighbor_feature_maps,
2870 )?;
2871
2872 let n_rows = current.num_rows();
2874 let mut inputs: Vec<ClassifyInput> = Vec::with_capacity(n_rows);
2875 let mut input_hashes: Vec<u64> = Vec::with_capacity(n_rows);
2876 for row_idx in 0..n_rows {
2877 let mut features = std::collections::HashMap::new();
2878 for resolver in &resolvers {
2879 let value = resolver.eval_row(¤t, row_idx)?;
2880 features.insert(resolver.binding_name.clone(), value);
2881 }
2882 let input = ClassifyInput { features };
2883 input_hashes.push(input.stable_hash());
2884 inputs.push(input);
2885 }
2886
2887 let mut probs: Vec<f64> = vec![0.0; n_rows];
2891 let mut miss_inputs: Vec<ClassifyInput> = Vec::new();
2892 let mut miss_row_indices: Vec<usize> = Vec::new();
2893 if let Some(c) = cache {
2894 for (row_idx, h) in input_hashes.iter().enumerate() {
2895 match c.get(&invocation.model_name, *h) {
2896 Some(v) => probs[row_idx] = v,
2897 None => {
2898 miss_row_indices.push(row_idx);
2899 miss_inputs.push(inputs[row_idx].clone());
2900 }
2901 }
2902 }
2903 } else {
2904 miss_row_indices = (0..n_rows).collect();
2905 miss_inputs = inputs.clone();
2906 }
2907
2908 let calibrator = classifier.get_calibrator();
2917 let (miss_raws, miss_calibrated) = if miss_inputs.is_empty() {
2918 (Vec::new(), Vec::new())
2919 } else if calibrator.is_some() {
2920 let pairs = classifier
2921 .raw_and_calibrated(&miss_inputs)
2922 .await
2923 .map_err(|e| datafusion::error::DataFusionError::Execution(e.to_string()))?;
2924 if pairs.len() != miss_inputs.len() {
2925 return Err(datafusion::error::DataFusionError::Execution(format!(
2926 "classifier '{}' raw_and_calibrated returned {} outputs for {} inputs",
2927 invocation.model_name,
2928 pairs.len(),
2929 miss_inputs.len()
2930 )));
2931 }
2932 let raws: Vec<f64> = pairs.iter().map(|(r, _)| *r).collect();
2933 let cals: Vec<f64> = pairs.iter().map(|(r, c)| c.unwrap_or(*r)).collect();
2934 (raws, cals)
2935 } else {
2936 let r = classifier
2937 .classify(&miss_inputs)
2938 .await
2939 .map_err(|e| datafusion::error::DataFusionError::Execution(e.to_string()))?;
2940 if r.len() != miss_inputs.len() {
2941 return Err(datafusion::error::DataFusionError::Execution(format!(
2942 "classifier '{}' returned {} outputs for {} inputs",
2943 invocation.model_name,
2944 r.len(),
2945 miss_inputs.len()
2946 )));
2947 }
2948 (r.clone(), r)
2950 };
2951 let mut row_raw: Vec<Option<f64>> = vec![None; n_rows];
2961 for (i, &row_idx) in miss_row_indices.iter().enumerate() {
2962 probs[row_idx] = miss_calibrated[i];
2963 row_raw[row_idx] = Some(miss_raws[i]);
2964 if let Some(c) = cache {
2965 c.insert(
2966 &invocation.model_name,
2967 input_hashes[row_idx],
2968 miss_calibrated[i],
2969 );
2970 }
2971 }
2972
2973 if let Some(store) = provenance_store {
2981 for row_idx in 0..n_rows {
2982 let calibrated_value = probs[row_idx];
2983 let (raw_value, calibrated) = match (row_raw[row_idx], &calibrator) {
2984 (Some(raw), Some(_)) => (raw, Some(calibrated_value)),
2985 (Some(raw), None) => (raw, None),
2986 (None, _) => (
2991 calibrated_value,
2992 calibrator.as_ref().map(|_| calibrated_value),
2993 ),
2994 };
2995 let band = calibrator
2996 .as_ref()
2997 .and_then(|c| c.confidence_band(calibrated_value));
2998 store.record(
2999 &invocation.model_name,
3000 input_hashes[row_idx],
3001 uni_locy::NeuralProvenanceRecord {
3002 raw_probability: raw_value,
3003 calibrated_probability: calibrated,
3004 confidence_band: band,
3005 feature_inputs: inputs[row_idx].features.clone(),
3013 },
3014 );
3015 }
3016 }
3017
3018 let out_col: Arc<dyn arrow_array::Array> =
3023 Arc::new(arrow_array::Float64Array::from(probs));
3024 let schema = current.schema();
3025 let target_idx = schema.index_of(&invocation.output_column).ok();
3026 let mut columns: Vec<Arc<dyn arrow_array::Array>> = current.columns().to_vec();
3027 let mut fields: Vec<Arc<arrow_schema::Field>> =
3028 schema.fields().iter().cloned().collect();
3029 match target_idx {
3030 Some(idx) => {
3031 columns[idx] = out_col;
3032 fields[idx] = Arc::new(arrow_schema::Field::new(
3035 &invocation.output_column,
3036 arrow_schema::DataType::Float64,
3037 true,
3038 ));
3039 }
3040 None => {
3041 columns.push(out_col);
3042 fields.push(Arc::new(arrow_schema::Field::new(
3043 &invocation.output_column,
3044 arrow_schema::DataType::Float64,
3045 true,
3046 )));
3047 }
3048 }
3049 let new_schema = Arc::new(arrow_schema::Schema::new(fields));
3050 current = RecordBatch::try_new(new_schema, columns).map_err(arrow_err)?;
3051 }
3052 out_batches.push(current);
3053 }
3054 Ok(out_batches)
3055}
3056
/// One named input feature of a model invocation, paired with the strategy
/// used to compute its value for each row of a body batch.
struct FeatureResolver {
    /// Key under which the computed value is inserted into the classifier's
    /// feature map.
    binding_name: String,
    /// How the value is computed for a row.
    kind: FeatureResolverKind,
}
3068
/// Per-row evaluation strategy for a [`FeatureResolver`].
enum FeatureResolverKind {
    /// Read the cell at this column index of the body batch.
    Direct(usize),
    /// Similarity score between two operands (`similar_to`, or `semantic_match`
    /// with a pre-embedded constant right-hand side).
    SimilarTo {
        left: FeatureValueSrc,
        right: FeatureValueSrc,
    },
    /// Look up the subject row's VID in a precomputed path-context table.
    PathContext {
        // Column holding the subject node's VID.
        subject_col: usize,
        vid_to_value: Arc<HashMap<u64, uni_locy::FeatureValue>>,
    },
    /// Look up the subject row's VID in a precomputed graph-algorithm score map.
    GraphAlgoScore {
        // Column holding the subject node's VID.
        subject_col: usize,
        vid_to_score: Arc<HashMap<u64, f64>>,
    },
    /// Aggregate (`op`) the precomputed value list stored for the subject VID.
    NeighborAggregate {
        // Column holding the subject node's VID.
        subject_col: usize,
        op: NeighborAgg,
        vid_to_values: Arc<HashMap<u64, Vec<f64>>>,
    },
}
3102
/// Aggregation applied to the per-subject value list of a neighbor-aggregator
/// FEATURE (`avg_neighbor`, `max_neighbor`, `sum_neighbor`).
#[derive(Debug, Clone, Copy)]
enum NeighborAgg {
    Avg,
    Max,
    Sum,
}
3109
/// Direction in which a neighbor-aggregator FEATURE follows relationships from
/// its subject node.
///
/// The direction is parsed case-insensitively from the optional fourth argument
/// (`OUTGOING` / `INCOMING` / `BOTH`) and defaults to `Outgoing`. It derives
/// `Eq + Hash` because it is part of the [`NeighborFeatureMaps`] key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
enum NeighborDirection {
    Outgoing,
    Incoming,
    Both,
}
3119
3120impl NeighborDirection {
3121 fn store_directions(self) -> &'static [uni_store::storage::direction::Direction] {
3122 use uni_store::storage::direction::Direction;
3123 match self {
3124 NeighborDirection::Outgoing => &[Direction::Outgoing],
3125 NeighborDirection::Incoming => &[Direction::Incoming],
3126 NeighborDirection::Both => &[Direction::Outgoing, Direction::Incoming],
3127 }
3128 }
3129}
3130
3131impl NeighborAgg {
3132 fn from_fn_name(name: &str) -> Option<Self> {
3133 match name {
3134 "avg_neighbor" => Some(NeighborAgg::Avg),
3135 "max_neighbor" => Some(NeighborAgg::Max),
3136 "sum_neighbor" => Some(NeighborAgg::Sum),
3137 _ => None,
3138 }
3139 }
3140
3141 fn apply(self, values: &[f64]) -> Option<f64> {
3142 if values.is_empty() {
3143 return None;
3144 }
3145 match self {
3146 NeighborAgg::Avg => Some(values.iter().sum::<f64>() / values.len() as f64),
3147 NeighborAgg::Max => values.iter().copied().reduce(f64::max),
3148 NeighborAgg::Sum => Some(values.iter().sum()),
3149 }
3150 }
3151}
3152
/// Operand of a similarity feature: either a column of the body batch (by
/// index) or a constant built from a literal or a list of literals.
enum FeatureValueSrc {
    Col(usize),
    Const(uni_common::Value),
}
3159
3160impl FeatureValueSrc {
3161 fn resolve(&self, batch: &RecordBatch, row_idx: usize) -> uni_common::Value {
3162 match self {
3163 FeatureValueSrc::Col(idx) => extract_common_value(batch.column(*idx).as_ref(), row_idx),
3164 FeatureValueSrc::Const(v) => v.clone(),
3165 }
3166 }
3167}
3168
3169impl FeatureResolver {
3170 fn eval_row(&self, batch: &RecordBatch, row_idx: usize) -> DFResult<uni_locy::FeatureValue> {
3171 match &self.kind {
3172 FeatureResolverKind::Direct(idx) => {
3173 Ok(extract_feature_value(batch.column(*idx).as_ref(), row_idx))
3174 }
3175 FeatureResolverKind::SimilarTo { left, right } => {
3176 let lv = left.resolve(batch, row_idx);
3177 let rv = right.resolve(batch, row_idx);
3178 match crate::query::similar_to::eval_similar_to_pure(&lv, &rv) {
3179 Ok(uni_common::Value::Float(f)) => Ok(uni_locy::FeatureValue::Float(f)),
3180 Ok(_) => Ok(uni_locy::FeatureValue::Null),
3181 Err(e) => Err(datafusion::error::DataFusionError::Execution(format!(
3182 "similar_to UDF failed: {e}"
3183 ))),
3184 }
3185 }
3186 FeatureResolverKind::PathContext {
3187 subject_col,
3188 vid_to_value,
3189 } => {
3190 let col = batch.column(*subject_col);
3191 if col.is_null(row_idx) {
3192 return Ok(uni_locy::FeatureValue::Null);
3193 }
3194 if let Some(arr) = col.as_any().downcast_ref::<arrow_array::UInt64Array>() {
3195 let vid = arr.value(row_idx);
3196 Ok(vid_to_value
3197 .get(&vid)
3198 .cloned()
3199 .unwrap_or(uni_locy::FeatureValue::Null))
3200 } else if let Some(arr) = col.as_any().downcast_ref::<arrow_array::Int64Array>() {
3201 let vid = arr.value(row_idx) as u64;
3202 Ok(vid_to_value
3203 .get(&vid)
3204 .cloned()
3205 .unwrap_or(uni_locy::FeatureValue::Null))
3206 } else {
3207 Ok(uni_locy::FeatureValue::Null)
3208 }
3209 }
3210 FeatureResolverKind::GraphAlgoScore {
3211 subject_col,
3212 vid_to_score,
3213 } => {
3214 let col = batch.column(*subject_col);
3215 if col.is_null(row_idx) {
3216 return Ok(uni_locy::FeatureValue::Null);
3217 }
3218 let vid_opt: Option<u64> = if let Some(arr) =
3219 col.as_any().downcast_ref::<arrow_array::UInt64Array>()
3220 {
3221 Some(arr.value(row_idx))
3222 } else if let Some(arr) = col.as_any().downcast_ref::<arrow_array::Int64Array>() {
3223 Some(arr.value(row_idx) as u64)
3224 } else {
3225 match extract_common_value(col.as_ref(), row_idx) {
3231 uni_common::Value::Node(n) => Some(n.vid.as_u64()),
3232 uni_common::Value::Int(i) => Some(i as u64),
3233 _ => None,
3234 }
3235 };
3236 Ok(vid_opt
3237 .and_then(|v| vid_to_score.get(&v).copied())
3238 .map(uni_locy::FeatureValue::Float)
3239 .unwrap_or(uni_locy::FeatureValue::Null))
3240 }
3241 FeatureResolverKind::NeighborAggregate {
3242 subject_col,
3243 op,
3244 vid_to_values,
3245 } => {
3246 let vid_opt = extract_vid_from_column(batch.column(*subject_col).as_ref(), row_idx);
3247 Ok(vid_opt
3248 .and_then(|v| vid_to_values.get(&v))
3249 .and_then(|values| op.apply(values))
3250 .map(uni_locy::FeatureValue::Float)
3251 .unwrap_or(uni_locy::FeatureValue::Null))
3252 }
3253 }
3254 }
3255}
3256
3257fn extract_vid_from_column(col: &dyn arrow_array::Array, row_idx: usize) -> Option<u64> {
3262 if col.is_null(row_idx) {
3263 return None;
3264 }
3265 if let Some(arr) = col.as_any().downcast_ref::<arrow_array::UInt64Array>() {
3266 return Some(arr.value(row_idx));
3267 }
3268 if let Some(arr) = col.as_any().downcast_ref::<arrow_array::Int64Array>() {
3269 return Some(arr.value(row_idx) as u64);
3270 }
3271 match extract_common_value(col, row_idx) {
3272 uni_common::Value::Node(n) => Some(n.vid.as_u64()),
3273 uni_common::Value::Int(i) => Some(i as u64),
3274 _ => None,
3275 }
3276}
3277
3278#[allow(clippy::too_many_arguments)]
3279fn build_feature_resolvers(
3280 batch: &RecordBatch,
3281 invocation: &uni_locy::ModelInvocation,
3282 path_context_handles: &HashMap<
3283 String,
3284 crate::query::df_graph::locy_model_invoke::PathContextHandle,
3285 >,
3286 semantic_match_embeddings: &HashMap<String, Vec<f32>>,
3287 graph_feature_maps: &HashMap<String, Arc<HashMap<u64, f64>>>,
3288 neighbor_feature_maps: &NeighborFeatureMaps,
3289) -> DFResult<Vec<FeatureResolver>> {
3290 use uni_cypher::ast::Expr;
3291 let schema = batch.schema();
3292 let lookup_col = |name_or_property: String| -> DFResult<usize> {
3293 schema.index_of(&name_or_property).map_err(|_| {
3294 datafusion::error::DataFusionError::Execution(format!(
3295 "feature column '{name_or_property}' not found in clause body output schema"
3296 ))
3297 })
3298 };
3299 let resolve_src = |expr: &Expr| -> DFResult<FeatureValueSrc> {
3304 match expr {
3305 Expr::Variable(name) => {
3306 let col = if schema.index_of(name).is_ok() {
3307 name.clone()
3308 } else {
3309 let vid_name = format!("{}._vid", name);
3310 if schema.index_of(&vid_name).is_ok() {
3311 vid_name
3312 } else {
3313 name.clone()
3314 }
3315 };
3316 Ok(FeatureValueSrc::Col(lookup_col(col)?))
3317 }
3318 Expr::Property(boxed, prop) if matches!(boxed.as_ref(), Expr::Variable(_)) => {
3319 let Expr::Variable(v) = boxed.as_ref() else {
3320 unreachable!()
3321 };
3322 let direct = format!("{}.{}", v, prop);
3323 let col = if schema.index_of(&direct).is_ok() {
3324 direct
3325 } else {
3326 format!("__feat_{}_{}", v, prop)
3327 };
3328 Ok(FeatureValueSrc::Col(lookup_col(col)?))
3329 }
3330 Expr::Literal(lit) => Ok(FeatureValueSrc::Const(lit.to_value())),
3331 Expr::List(items) => {
3332 let mut out = Vec::with_capacity(items.len());
3333 for it in items {
3334 out.push(match it {
3335 Expr::Literal(lit) => lit.to_value(),
3336 _ => uni_common::Value::Null,
3337 });
3338 }
3339 Ok(FeatureValueSrc::Const(uni_common::Value::List(out)))
3340 }
3341 other => Err(datafusion::error::DataFusionError::Execution(format!(
3342 "unsupported feature sub-expression: {other:?}"
3343 ))),
3344 }
3345 };
3346
3347 if let Some(pc) = &invocation.path_context {
3355 let handle = path_context_handles.get(&pc.source_rule).ok_or_else(|| {
3356 datafusion::error::DataFusionError::Execution(format!(
3357 "model '{}' path_context references rule '{}' but no DerivedScanHandle \
3358 was registered; this should never happen — the build_clause path \
3359 mints a handle for every distinct source_rule in the invocation set",
3360 invocation.model_name, pc.source_rule
3361 ))
3362 })?;
3363 let subject_col = schema
3364 .index_of(&format!("{}._vid", pc.subject_var))
3365 .or_else(|_| schema.index_of(&pc.subject_var))
3366 .map_err(|_| {
3367 datafusion::error::DataFusionError::Execution(format!(
3368 "model '{}' path_context: subject column '{}' (or '{0}._vid') not \
3369 in body batch schema",
3370 invocation.model_name, pc.subject_var
3371 ))
3372 })?;
3373 let vid_to_value =
3374 build_path_context_lookup(handle, &pc.subject_var, &pc.column, &invocation.model_name)?;
3375 return Ok(vec![FeatureResolver {
3376 binding_name: pc.column.clone(),
3377 kind: FeatureResolverKind::PathContext {
3378 subject_col,
3379 vid_to_value: Arc::new(vid_to_value),
3380 },
3381 }]);
3382 }
3383
3384 let mut out = Vec::with_capacity(invocation.feature_exprs.len());
3385 for (i, fexpr) in invocation.feature_exprs.iter().enumerate() {
3386 let binding_name = invocation.feature_names[i].clone();
3387 let kind = match fexpr {
3388 Expr::FunctionCall { name, args, .. } if name == "similar_to" => {
3389 if args.len() != 2 {
3390 return Err(datafusion::error::DataFusionError::Execution(format!(
3391 "similar_to expects 2 args, got {}",
3392 args.len()
3393 )));
3394 }
3395 FeatureResolverKind::SimilarTo {
3396 left: resolve_src(&args[0])?,
3397 right: resolve_src(&args[1])?,
3398 }
3399 }
3400 Expr::FunctionCall { name, args, .. } if name == "semantic_match" => {
3401 if args.len() != 2 {
3406 return Err(datafusion::error::DataFusionError::Execution(format!(
3407 "semantic_match expects 2 args, got {}",
3408 args.len()
3409 )));
3410 }
3411 let text = match &args[1] {
3412 Expr::Literal(uni_cypher::ast::CypherLiteral::String(s)) => s.clone(),
3413 other => {
3414 return Err(datafusion::error::DataFusionError::Execution(format!(
3415 "semantic_match: 2nd arg must be a string literal, got {other:?}"
3416 )));
3417 }
3418 };
3419 let embedded = semantic_match_embeddings.get(&text).ok_or_else(|| {
3420 datafusion::error::DataFusionError::Execution(format!(
3421 "semantic_match: query text '{text}' was not pre-embedded. \
3422 This is a bug — `apply_model_invocations` should have \
3423 embedded all unique semantic_match texts up front. Most \
3424 likely the Xervo runtime is not configured (configure \
3425 via `LocyConfig::xervo_runtime` or its equivalent)."
3426 ))
3427 })?;
3428 let right_vec: Vec<f32> = embedded.clone();
3429 FeatureResolverKind::SimilarTo {
3430 left: resolve_src(&args[0])?,
3431 right: FeatureValueSrc::Const(uni_common::Value::Vector(right_vec)),
3432 }
3433 }
3434 Expr::FunctionCall { name, args, .. }
3435 if matches!(
3436 name.as_str(),
3437 "degree_centrality"
3438 | "pagerank_score"
3439 | "closeness_centrality"
3440 | "betweenness_centrality"
3441 | "eigenvector_centrality"
3442 | "harmonic_centrality"
3443 | "katz_centrality"
3444 ) =>
3445 {
3446 if args.len() != 1 {
3447 return Err(datafusion::error::DataFusionError::Execution(format!(
3448 "{name} expects 1 arg, got {}",
3449 args.len()
3450 )));
3451 }
3452 let Expr::Variable(v) = &args[0] else {
3453 return Err(datafusion::error::DataFusionError::Execution(format!(
3454 "{name}(...) argument must be a node variable, got {:?}",
3455 args[0]
3456 )));
3457 };
3458 let subject_col = {
3459 let direct = schema.index_of(v).ok();
3460 let vid_name = format!("{}._vid", v);
3461 let vid_col = schema.index_of(&vid_name).ok();
3462 vid_col.or(direct).ok_or_else(|| {
3463 datafusion::error::DataFusionError::Execution(format!(
3464 "{name}: subject column '{v}' (or '{v}._vid') not in body batch schema"
3465 ))
3466 })?
3467 };
3468 let vid_to_score = graph_feature_maps.get(name).cloned().ok_or_else(|| {
3469 datafusion::error::DataFusionError::Execution(format!(
3470 "{name}: pre-computed score map missing. This is a bug — \
3471 `apply_model_invocations` should have called \
3472 `precompute_graph_feature_maps` for every graph-structural \
3473 feature before building resolvers. Most likely the graph \
3474 algorithm registry is not configured."
3475 ))
3476 })?;
3477 FeatureResolverKind::GraphAlgoScore {
3478 subject_col,
3479 vid_to_score,
3480 }
3481 }
3482 Expr::FunctionCall { name, args, .. }
3483 if matches!(
3484 name.as_str(),
3485 "avg_neighbor" | "max_neighbor" | "sum_neighbor"
3486 ) =>
3487 {
3488 if args.len() != 3 && args.len() != 4 {
3489 return Err(datafusion::error::DataFusionError::Execution(format!(
3490 "{name} expects 3 or 4 args, got {}",
3491 args.len()
3492 )));
3493 }
3494 let Expr::Variable(v) = &args[0] else {
3495 return Err(datafusion::error::DataFusionError::Execution(format!(
3496 "{name}(...) first argument must be a node variable, got {:?}",
3497 args[0]
3498 )));
3499 };
3500 let rel_type = match &args[1] {
3501 Expr::Literal(uni_cypher::ast::CypherLiteral::String(s)) => s.clone(),
3502 other => {
3503 return Err(datafusion::error::DataFusionError::Execution(format!(
3504 "{name}: 2nd arg must be a string literal (rel-type), got {other:?}"
3505 )));
3506 }
3507 };
3508 let prop_name = match &args[2] {
3509 Expr::Literal(uni_cypher::ast::CypherLiteral::String(s)) => s.clone(),
3510 other => {
3511 return Err(datafusion::error::DataFusionError::Execution(format!(
3512 "{name}: 3rd arg must be a string literal (property), got {other:?}"
3513 )));
3514 }
3515 };
3516 let direction_arg = match args.get(3) {
3517 None => NeighborDirection::Outgoing,
3518 Some(Expr::Literal(uni_cypher::ast::CypherLiteral::String(d))) => {
3519 match d.to_uppercase().as_str() {
3520 "OUTGOING" => NeighborDirection::Outgoing,
3521 "INCOMING" => NeighborDirection::Incoming,
3522 "BOTH" => NeighborDirection::Both,
3523 other => {
3524 return Err(datafusion::error::DataFusionError::Execution(
3525 format!(
3526 "{name}: direction must be OUTGOING|INCOMING|BOTH, got '{other}'"
3527 ),
3528 ));
3529 }
3530 }
3531 }
3532 Some(other) => {
3533 return Err(datafusion::error::DataFusionError::Execution(format!(
3534 "{name}: 4th arg must be a string literal (direction), got {other:?}"
3535 )));
3536 }
3537 };
3538 let subject_col = {
3539 let direct = schema.index_of(v).ok();
3540 let vid_name = format!("{}._vid", v);
3541 let vid_col = schema.index_of(&vid_name).ok();
3542 vid_col.or(direct).ok_or_else(|| {
3543 datafusion::error::DataFusionError::Execution(format!(
3544 "{name}: subject column '{v}' (or '{v}._vid') not in body batch schema"
3545 ))
3546 })?
3547 };
3548 let vid_to_values = neighbor_feature_maps
3549 .get(&(rel_type.clone(), prop_name.clone(), direction_arg))
3550 .cloned()
3551 .ok_or_else(|| {
3552 datafusion::error::DataFusionError::Execution(format!(
3553 "{name}: pre-computed neighbor map missing for ({rel_type}, {prop_name}, {direction_arg:?}). \
3554 This is a bug — `apply_model_invocations` should have called \
3555 `precompute_neighbor_feature_maps` for every neighbor-aggregator \
3556 feature before building resolvers."
3557 ))
3558 })?;
3559 let op = NeighborAgg::from_fn_name(name).unwrap();
3560 FeatureResolverKind::NeighborAggregate {
3561 subject_col,
3562 op,
3563 vid_to_values,
3564 }
3565 }
3566 other => match resolve_src(other)? {
3567 FeatureValueSrc::Col(idx) => FeatureResolverKind::Direct(idx),
3568 FeatureValueSrc::Const(_) => {
3569 return Err(datafusion::error::DataFusionError::Execution(format!(
3570 "model '{}' feature must reference a variable or property — got a literal",
3571 invocation.model_name
3572 )));
3573 }
3574 },
3575 };
3576 out.push(FeatureResolver { binding_name, kind });
3577 }
3578 Ok(out)
3579}
3580
3581async fn pre_embed_semantic_match_queries(
3588 invocations: &[uni_locy::ModelInvocation],
3589 xervo_runtime: &crate::query::df_graph::locy_model_invoke::XervoRuntimeHandle,
3590) -> DFResult<HashMap<String, Vec<f32>>> {
3591 use uni_cypher::ast::{CypherLiteral, Expr};
3592 let mut needed: Vec<(String, String)> = Vec::new();
3602 for inv in invocations {
3603 let alias = inv
3604 .embedder_alias
3605 .clone()
3606 .unwrap_or_else(|| "default".to_string());
3607 for fexpr in &inv.feature_exprs {
3608 if let Expr::FunctionCall { name, args, .. } = fexpr
3609 && name == "semantic_match"
3610 && args.len() == 2
3611 && let Expr::Literal(CypherLiteral::String(s)) = &args[1]
3612 {
3613 let tuple = (s.clone(), alias.clone());
3614 if !needed.contains(&tuple) {
3615 needed.push(tuple);
3616 }
3617 }
3618 }
3619 }
3620 if needed.is_empty() {
3621 return Ok(HashMap::new());
3622 }
3623 let runtime = xervo_runtime.as_ref().ok_or_else(|| {
3624 datafusion::error::DataFusionError::Execution(
3625 "semantic_match: Uni-Xervo runtime not configured. Either provide \
3626 one via `LocyConfig::xervo_runtime` (or its equivalent setup \
3627 path) or pre-compute the query embedding and pass it via \
3628 `similar_to(prop, <literal_vector>)`."
3629 .to_string(),
3630 )
3631 })?;
3632 let mut by_alias: HashMap<String, Vec<String>> = HashMap::new();
3635 for (text, alias) in &needed {
3636 by_alias
3637 .entry(alias.clone())
3638 .or_default()
3639 .push(text.clone());
3640 }
3641 let mut out: HashMap<String, Vec<f32>> = HashMap::new();
3642 for (alias, texts) in by_alias {
3643 let embedder = runtime.embedding(&alias).await.map_err(|e| {
3644 datafusion::error::DataFusionError::Execution(format!(
3645 "semantic_match: failed to obtain embedder for alias '{alias}': {e}. \
3646 Register an embedder under that alias in your Uni-Xervo runtime, or \
3647 pre-compute the query embedding and pass via similar_to."
3648 ))
3649 })?;
3650 let text_refs: Vec<&str> = texts.iter().map(|s| s.as_str()).collect();
3651 let embeddings = embedder.embed(text_refs).await.map_err(|e| {
3652 datafusion::error::DataFusionError::Execution(format!(
3653 "semantic_match: embedder '{alias}' call failed: {e}"
3654 ))
3655 })?;
3656 if embeddings.len() != texts.len() {
3657 return Err(datafusion::error::DataFusionError::Execution(format!(
3658 "semantic_match: embedder '{alias}' returned {} vectors for {} queries",
3659 embeddings.len(),
3660 texts.len()
3661 )));
3662 }
3663 for (text, vec) in texts.into_iter().zip(embeddings) {
3664 out.insert(text, vec);
3665 }
3666 }
3667 Ok(out)
3668}
3669
3670async fn precompute_graph_feature_maps(
3683 invocations: &[uni_locy::ModelInvocation],
3684 graph_algo: &crate::query::df_graph::locy_model_invoke::GraphAlgoHandle,
3685) -> DFResult<HashMap<String, Arc<HashMap<u64, f64>>>> {
3686 use futures::StreamExt;
3687 use uni_algo::algo::procedures::AlgoContext;
3688 use uni_cypher::ast::Expr;
3689
3690 fn procedure_for(fn_name: &str) -> Option<&'static str> {
3693 match fn_name {
3694 "degree_centrality" => Some("uni.algo.degreeCentrality"),
3695 "pagerank_score" => Some("uni.algo.pageRank"),
3696 "closeness_centrality" => Some("uni.algo.closeness"),
3697 "betweenness_centrality" => Some("uni.algo.betweenness"),
3698 "eigenvector_centrality" => Some("uni.algo.eigenvectorCentrality"),
3699 "harmonic_centrality" => Some("uni.algo.harmonicCentrality"),
3700 "katz_centrality" => Some("uni.algo.katzCentrality"),
3701 _ => None,
3702 }
3703 }
3704
3705 let mut needed: Vec<String> = Vec::new();
3709 for inv in invocations {
3710 for fexpr in &inv.feature_exprs {
3711 if let Expr::FunctionCall { name, .. } = fexpr
3712 && procedure_for(name).is_some()
3713 && !needed.contains(name)
3714 {
3715 needed.push(name.clone());
3716 }
3717 }
3718 }
3719 if needed.is_empty() {
3720 return Ok(HashMap::new());
3721 }
3722
3723 let registry = graph_algo.registry.as_ref().ok_or_else(|| {
3724 datafusion::error::DataFusionError::Execution(
3725 "graph-structural FEATURE invoked but no `AlgorithmRegistry` is \
3726 configured. Configure one on `GraphExecutionContext::with_algo_registry`."
3727 .to_string(),
3728 )
3729 })?;
3730 let storage = graph_algo.storage.as_ref().ok_or_else(|| {
3731 datafusion::error::DataFusionError::Execution(
3732 "graph-structural FEATURE invoked but no storage handle was \
3733 threaded into the FEATURE runtime. This is a bug in df_planner."
3734 .to_string(),
3735 )
3736 })?;
3737
3738 let mut out: HashMap<String, Arc<HashMap<u64, f64>>> = HashMap::new();
3739 for fn_name in needed {
3740 let proc_name = procedure_for(&fn_name).unwrap();
3741 let procedure = registry.get(proc_name).ok_or_else(|| {
3742 datafusion::error::DataFusionError::Execution(format!(
3743 "graph-structural FEATURE '{fn_name}' resolves to procedure \
3744 '{proc_name}' which is not in the algorithm registry"
3745 ))
3746 })?;
3747 let args: Vec<serde_json::Value> = vec![
3752 serde_json::Value::Array(Vec::new()),
3753 serde_json::Value::Array(Vec::new()),
3754 ];
3755 let algo_ctx = AlgoContext::new(
3756 storage.clone(),
3757 graph_algo.l0_manager.as_ref().map(Arc::clone),
3758 );
3759 let filled_args = procedure
3780 .signature()
3781 .validate_args(args.clone())
3782 .map_err(|e| {
3783 datafusion::error::DataFusionError::Execution(format!(
3784 "graph-structural FEATURE '{fn_name}': argument validation failed: {e}"
3785 ))
3786 })?;
3787 let projection = uni_algo::algo::procedure_template::build_projection_from_direct_args(
3788 procedure.as_ref(),
3789 &algo_ctx,
3790 &filled_args,
3791 )
3792 .await
3793 .map_err(|e| {
3794 datafusion::error::DataFusionError::Execution(format!(
3795 "graph-structural FEATURE '{fn_name}': projection build failed: {e}"
3796 ))
3797 })?;
3798 let mut stream = procedure.execute_with_projection(algo_ctx, args, projection);
3799 let mut score_map: HashMap<u64, f64> = HashMap::new();
3800 let sig = procedure.signature();
3801 let node_idx = sig
3802 .yields
3803 .iter()
3804 .position(|(n, _)| *n == "nodeId")
3805 .ok_or_else(|| {
3806 datafusion::error::DataFusionError::Execution(format!(
3807 "procedure '{proc_name}' yield schema missing 'nodeId'"
3808 ))
3809 })?;
3810 let score_idx = sig
3815 .yields
3816 .iter()
3817 .position(|(n, _)| *n == "score" || *n == "centrality")
3818 .ok_or_else(|| {
3819 datafusion::error::DataFusionError::Execution(format!(
3820 "procedure '{proc_name}' yield schema missing a numeric score column \
3821 (expected 'score' or 'centrality')"
3822 ))
3823 })?;
3824 while let Some(row_res) = stream.next().await {
3825 let row = row_res.map_err(|e| {
3826 datafusion::error::DataFusionError::Execution(format!(
3827 "graph-structural FEATURE '{fn_name}': procedure '{proc_name}' failed: {e}"
3828 ))
3829 })?;
3830 let vid_v = row.values.get(node_idx);
3831 let score_v = row.values.get(score_idx);
3832 let (Some(vid_v), Some(score_v)) = (vid_v, score_v) else {
3833 continue;
3834 };
3835 let vid = vid_v.as_u64().or_else(|| vid_v.as_i64().map(|i| i as u64));
3836 let score = score_v
3837 .as_f64()
3838 .or_else(|| score_v.as_i64().map(|i| i as f64));
3839 if let (Some(vid), Some(score)) = (vid, score) {
3840 score_map.insert(vid, score);
3841 }
3842 }
3843 out.insert(fn_name, Arc::new(score_map));
3844 }
3845 Ok(out)
3846}
3847
/// Precomputed neighbor-aggregator tables, keyed by
/// `(relationship type, property name, direction)`.
///
/// Each table maps a subject VID to the list of values that `NeighborAgg`
/// aggregates. Presumably these are the `prop` values of the subject's
/// neighbors over that relationship; see `precompute_neighbor_feature_maps` to
/// confirm.
type NeighborFeatureMaps =
    HashMap<(String, String, NeighborDirection), Arc<HashMap<u64, Vec<f64>>>>;
3875
3876async fn precompute_neighbor_feature_maps(
3877 invocations: &[uni_locy::ModelInvocation],
3878 batches: &[RecordBatch],
3879 graph_algo: &crate::query::df_graph::locy_model_invoke::GraphAlgoHandle,
3880) -> DFResult<NeighborFeatureMaps> {
3881 use uni_cypher::ast::{CypherLiteral, Expr};
3882
3883 let parse_direction = |arg: Option<&Expr>| -> Option<NeighborDirection> {
3888 match arg {
3889 None => Some(NeighborDirection::Outgoing),
3890 Some(Expr::Literal(CypherLiteral::String(d))) => match d.to_uppercase().as_str() {
3891 "OUTGOING" => Some(NeighborDirection::Outgoing),
3892 "INCOMING" => Some(NeighborDirection::Incoming),
3893 "BOTH" => Some(NeighborDirection::Both),
3894 _ => None,
3895 },
3896 _ => None,
3897 }
3898 };
3899 let mut needed: Vec<(String, String, String, NeighborDirection)> = Vec::new();
3900 for inv in invocations {
3901 for fexpr in &inv.feature_exprs {
3902 if let Expr::FunctionCall { name, args, .. } = fexpr
3903 && NeighborAgg::from_fn_name(name).is_some()
3904 && (args.len() == 3 || args.len() == 4)
3905 && let Expr::Variable(v) = &args[0]
3906 && let Expr::Literal(CypherLiteral::String(rel)) = &args[1]
3907 && let Expr::Literal(CypherLiteral::String(prop)) = &args[2]
3908 && let Some(direction) = parse_direction(args.get(3))
3909 {
3910 let tuple = (v.clone(), rel.clone(), prop.clone(), direction);
3911 if !needed.contains(&tuple) {
3912 needed.push(tuple);
3913 }
3914 }
3915 }
3916 }
3917 if needed.is_empty() {
3918 return Ok(HashMap::new());
3919 }
3920
3921 let storage = graph_algo.storage.as_ref().ok_or_else(|| {
3922 datafusion::error::DataFusionError::Execution(
3923 "neighbor-aggregator FEATURE invoked but no storage handle was \
3924 threaded into the FEATURE runtime. This is a bug in df_planner."
3925 .to_string(),
3926 )
3927 })?;
3928 let property_manager = graph_algo.property_manager.as_ref().ok_or_else(|| {
3929 datafusion::error::DataFusionError::Execution(
3930 "neighbor-aggregator FEATURE invoked but no PropertyManager was \
3931 threaded into the FEATURE runtime. This is a bug in df_planner."
3932 .to_string(),
3933 )
3934 })?;
3935 let query_ctx = graph_algo.l0_buffers.as_ref().map(|bufs| {
3941 uni_store::runtime::context::QueryContext::new_with_pending(
3942 bufs.current.clone(),
3943 bufs.transaction.clone(),
3944 bufs.pending_flush.clone(),
3945 )
3946 });
3947
3948 let mut by_key: HashMap<(String, String, NeighborDirection), Vec<String>> = HashMap::new();
3952 for (subject_var, rel, prop, direction) in needed {
3953 by_key
3954 .entry((rel, prop, direction))
3955 .or_default()
3956 .push(subject_var);
3957 }
3958
3959 let mut out: NeighborFeatureMaps = HashMap::new();
3960 for ((rel_type, prop_name, direction), subject_vars) in by_key {
3961 let schema = storage.schema_manager().schema();
3963 let Some(edge_meta) = schema.edge_types.get(&rel_type) else {
3964 out.insert((rel_type, prop_name, direction), Arc::new(HashMap::new()));
3967 continue;
3968 };
3969 let edge_type_id = edge_meta.id;
3970
3971 let edge_ver = storage.get_edge_version_by_id(edge_type_id);
3974 for dir in direction.store_directions() {
3975 storage
3976 .warm_adjacency(edge_type_id, *dir, edge_ver)
3977 .await
3978 .map_err(|e| {
3979 datafusion::error::DataFusionError::Execution(format!(
3980 "neighbor-aggregator warm_adjacency for '{rel_type}' / {dir:?} failed: {e}"
3981 ))
3982 })?;
3983 }
3984
3985 let mut subject_vids: std::collections::HashSet<u64> = std::collections::HashSet::new();
3988 for subject_var in &subject_vars {
3989 for batch in batches {
3990 let schema = batch.schema();
3991 let col_idx = schema
3992 .index_of(&format!("{}._vid", subject_var))
3993 .ok()
3994 .or_else(|| schema.index_of(subject_var).ok());
3995 let Some(col_idx) = col_idx else { continue };
3996 let col = batch.column(col_idx);
3997 for row in 0..batch.num_rows() {
3998 if let Some(v) = extract_vid_from_column(col.as_ref(), row) {
3999 subject_vids.insert(v);
4000 }
4001 }
4002 }
4003 }
4004
4005 let mut vid_to_values: HashMap<u64, Vec<f64>> = HashMap::new();
4010 let adj = storage.adjacency_manager();
4011 for subject_vid in subject_vids {
4012 let mut neighbors: Vec<(uni_common::core::id::Vid, uni_common::core::id::Eid)> =
4013 Vec::new();
4014 for dir in direction.store_directions() {
4015 neighbors.extend(adj.get_neighbors(
4016 uni_common::core::id::Vid::from(subject_vid),
4017 edge_type_id,
4018 *dir,
4019 ));
4020 }
4021 let mut values: Vec<f64> = Vec::with_capacity(neighbors.len());
4022 for (neighbor_vid, _eid) in neighbors {
4023 let val = property_manager
4024 .get_vertex_prop_with_ctx(neighbor_vid, &prop_name, query_ctx.as_ref())
4025 .await
4026 .map_err(|e| {
4027 datafusion::error::DataFusionError::Execution(format!(
4028 "neighbor-aggregator: failed to read property \
4029 '{prop_name}' on neighbor vid {neighbor_vid:?}: {e}"
4030 ))
4031 })?;
4032 if let Some(f) = val.as_f64()
4033 && !f.is_nan()
4034 {
4035 values.push(f);
4036 }
4037 }
4038 vid_to_values.insert(subject_vid, values);
4039 }
4040 out.insert((rel_type, prop_name, direction), Arc::new(vid_to_values));
4041 }
4042 Ok(out)
4043}
4044
4045fn build_path_context_lookup(
4051 handle: &crate::query::df_graph::locy_model_invoke::PathContextHandle,
4052 _subject_var: &str,
4053 column: &str,
4054 model_name: &str,
4055) -> DFResult<HashMap<u64, uni_locy::FeatureValue>> {
4056 if handle.schema.fields().is_empty() {
4061 return Err(datafusion::error::DataFusionError::Execution(format!(
4062 "model '{model_name}' path_context: source rule has empty yield schema"
4063 )));
4064 }
4065 let subj_idx = 0_usize;
4066 let col_idx = handle.schema.index_of(column).map_err(|_| {
4067 datafusion::error::DataFusionError::Execution(format!(
4068 "model '{model_name}' path_context: column '{column}' not in \
4069 source rule's yield schema (have: {:?})",
4070 handle
4071 .schema
4072 .fields()
4073 .iter()
4074 .map(|f| f.name().clone())
4075 .collect::<Vec<_>>()
4076 ))
4077 })?;
4078 let batches = handle.data.read();
4079 let mut out: HashMap<u64, uni_locy::FeatureValue> = HashMap::new();
4080 for batch in batches.iter() {
4081 let subj_col = batch.column(subj_idx);
4082 let value_col = batch.column(col_idx);
4083 for row in 0..batch.num_rows() {
4084 if subj_col.is_null(row) {
4085 continue;
4086 }
4087 let vid = if let Some(a) = subj_col.as_any().downcast_ref::<arrow_array::UInt64Array>()
4088 {
4089 a.value(row)
4090 } else if let Some(a) = subj_col.as_any().downcast_ref::<arrow_array::Int64Array>() {
4091 a.value(row) as u64
4092 } else {
4093 continue;
4094 };
4095 let v = extract_feature_value(value_col.as_ref(), row);
4096 out.insert(vid, v);
4099 }
4100 }
4101 Ok(out)
4102}
4103
4104fn extract_common_value(col: &dyn arrow_array::Array, row_idx: usize) -> uni_common::Value {
4109 use arrow_array::{BooleanArray, Float64Array, Int64Array, LargeStringArray, StringArray};
4110 if col.is_null(row_idx) {
4111 return uni_common::Value::Null;
4112 }
4113 if let Some(a) = col.as_any().downcast_ref::<Float64Array>() {
4114 return uni_common::Value::Float(a.value(row_idx));
4115 }
4116 if let Some(a) = col.as_any().downcast_ref::<Int64Array>() {
4117 return uni_common::Value::Int(a.value(row_idx));
4118 }
4119 if let Some(a) = col.as_any().downcast_ref::<BooleanArray>() {
4120 return uni_common::Value::Bool(a.value(row_idx));
4121 }
4122 if let Some(a) = col.as_any().downcast_ref::<StringArray>() {
4123 return uni_common::Value::String(a.value(row_idx).to_string());
4124 }
4125 if let Some(a) = col.as_any().downcast_ref::<LargeStringArray>() {
4126 return uni_common::Value::String(a.value(row_idx).to_string());
4127 }
4128 if let Some(b) = col.as_any().downcast_ref::<arrow_array::LargeBinaryArray>() {
4129 let bytes = b.value(row_idx);
4130 if bytes.is_empty() {
4131 return uni_common::Value::Null;
4132 }
4133 return uni_common::cypher_value_codec::decode(bytes).unwrap_or(uni_common::Value::Null);
4134 }
4135 uni_common::Value::Null
4136}
4137
4138fn extract_feature_value(col: &dyn arrow_array::Array, row_idx: usize) -> uni_locy::FeatureValue {
4139 use arrow_array::{BooleanArray, Float64Array, Int64Array, LargeStringArray, StringArray};
4140 if col.is_null(row_idx) {
4141 return uni_locy::FeatureValue::Null;
4142 }
4143 if let Some(a) = col.as_any().downcast_ref::<Float64Array>() {
4144 return uni_locy::FeatureValue::Float(a.value(row_idx));
4145 }
4146 if let Some(a) = col.as_any().downcast_ref::<Int64Array>() {
4147 return uni_locy::FeatureValue::Int(a.value(row_idx));
4148 }
4149 if let Some(a) = col.as_any().downcast_ref::<BooleanArray>() {
4150 return uni_locy::FeatureValue::Bool(a.value(row_idx));
4151 }
4152 if let Some(a) = col.as_any().downcast_ref::<StringArray>() {
4153 return uni_locy::FeatureValue::String(a.value(row_idx).to_string());
4154 }
4155 if let Some(a) = col.as_any().downcast_ref::<LargeStringArray>() {
4156 return uni_locy::FeatureValue::String(a.value(row_idx).to_string());
4157 }
4158 if let Some(b) = col.as_any().downcast_ref::<arrow_array::LargeBinaryArray>() {
4162 let bytes = b.value(row_idx);
4163 if bytes.is_empty() {
4164 return uni_locy::FeatureValue::Null;
4165 }
4166 let v = uni_common::cypher_value_codec::decode(bytes).unwrap_or(uni_common::Value::Null);
4167 return match v {
4168 uni_common::Value::Float(f) => uni_locy::FeatureValue::Float(f),
4169 uni_common::Value::Int(i) => uni_locy::FeatureValue::Int(i),
4170 uni_common::Value::Bool(b) => uni_locy::FeatureValue::Bool(b),
4171 uni_common::Value::String(s) => uni_locy::FeatureValue::String(s),
4172 uni_common::Value::Null => uni_locy::FeatureValue::Null,
4173 _ => uni_locy::FeatureValue::Null,
4174 };
4175 }
4176 uni_locy::FeatureValue::Null
4177}
4178
4179pub fn apply_prob_complement(
4186 batches: Vec<RecordBatch>,
4187 neg_facts: &[RecordBatch],
4188 left_col: &str,
4189 right_col: &str,
4190 prob_col: &str,
4191 complement_col_name: &str,
4192) -> datafusion::error::Result<Vec<RecordBatch>> {
4193 use arrow_array::{Array as _, Float64Array, UInt64Array};
4194
4195 let mut prob_map: std::collections::HashMap<u64, f64> = std::collections::HashMap::new();
4197 for batch in neg_facts {
4198 let Ok(vid_idx) = batch.schema().index_of(right_col) else {
4199 continue;
4200 };
4201 let Ok(prob_idx) = batch.schema().index_of(prob_col) else {
4202 continue;
4203 };
4204 let Some(vids) = batch.column(vid_idx).as_any().downcast_ref::<UInt64Array>() else {
4205 continue;
4206 };
4207 let prob_arr = batch.column(prob_idx);
4208 let probs = prob_arr.as_any().downcast_ref::<Float64Array>();
4209 for i in 0..vids.len() {
4210 if !vids.is_null(i) {
4211 let p = probs
4212 .and_then(|arr| {
4213 if arr.is_null(i) {
4214 None
4215 } else {
4216 Some(arr.value(i))
4217 }
4218 })
4219 .unwrap_or(0.0);
4220 prob_map
4223 .entry(vids.value(i))
4224 .and_modify(|existing| {
4225 *existing = 1.0 - (1.0 - *existing) * (1.0 - p);
4226 })
4227 .or_insert(p);
4228 }
4229 }
4230 }
4231
4232 let mut result = Vec::new();
4234 for batch in batches {
4235 let Ok(idx) = batch.schema().index_of(left_col) else {
4236 result.push(batch);
4237 continue;
4238 };
4239 let Some(vids) = batch.column(idx).as_any().downcast_ref::<UInt64Array>() else {
4240 result.push(batch);
4241 continue;
4242 };
4243
4244 let complements: Vec<f64> = (0..vids.len())
4246 .map(|i| {
4247 if vids.is_null(i) {
4248 1.0
4249 } else {
4250 let p = prob_map.get(&vids.value(i)).copied().unwrap_or(0.0);
4251 1.0 - p
4252 }
4253 })
4254 .collect();
4255
4256 let complement_arr = Float64Array::from(complements);
4257
4258 let mut columns: Vec<arrow_array::ArrayRef> = batch.columns().to_vec();
4260 columns.push(std::sync::Arc::new(complement_arr));
4261
4262 let mut fields: Vec<std::sync::Arc<arrow_schema::Field>> =
4263 batch.schema().fields().iter().cloned().collect();
4264 fields.push(std::sync::Arc::new(arrow_schema::Field::new(
4265 complement_col_name,
4266 arrow_schema::DataType::Float64,
4267 true,
4268 )));
4269
4270 let new_schema = std::sync::Arc::new(arrow_schema::Schema::new(fields));
4271 let new_batch = RecordBatch::try_new(new_schema, columns).map_err(arrow_err)?;
4272 result.push(new_batch);
4273 }
4274 Ok(result)
4275}
4276
4277pub fn apply_prob_complement_composite(
4284 batches: Vec<RecordBatch>,
4285 neg_facts: &[RecordBatch],
4286 join_cols: &[(String, String)],
4287 prob_col: &str,
4288 complement_col_name: &str,
4289) -> datafusion::error::Result<Vec<RecordBatch>> {
4290 use arrow_array::{Array as _, Float64Array, UInt64Array};
4291
4292 let mut prob_map: HashMap<Vec<u64>, f64> = HashMap::new();
4294 for batch in neg_facts {
4295 let right_indices: Vec<usize> = join_cols
4296 .iter()
4297 .filter_map(|(_, rc)| batch.schema().index_of(rc).ok())
4298 .collect();
4299 if right_indices.len() != join_cols.len() {
4300 continue;
4301 }
4302 let Ok(prob_idx) = batch.schema().index_of(prob_col) else {
4303 continue;
4304 };
4305 let prob_arr = batch.column(prob_idx);
4306 let probs = prob_arr.as_any().downcast_ref::<Float64Array>();
4307 for row in 0..batch.num_rows() {
4308 let mut key = Vec::with_capacity(right_indices.len());
4309 let mut valid = true;
4310 for &ci in &right_indices {
4311 let col = batch.column(ci);
4312 if let Some(vids) = col.as_any().downcast_ref::<UInt64Array>() {
4313 if vids.is_null(row) {
4314 valid = false;
4315 break;
4316 }
4317 key.push(vids.value(row));
4318 } else {
4319 valid = false;
4320 break;
4321 }
4322 }
4323 if !valid {
4324 continue;
4325 }
4326 let p = probs
4327 .and_then(|arr| {
4328 if arr.is_null(row) {
4329 None
4330 } else {
4331 Some(arr.value(row))
4332 }
4333 })
4334 .unwrap_or(0.0);
4335 prob_map
4337 .entry(key)
4338 .and_modify(|existing| {
4339 *existing = 1.0 - (1.0 - *existing) * (1.0 - p);
4340 })
4341 .or_insert(p);
4342 }
4343 }
4344
4345 let mut result = Vec::new();
4347 for batch in batches {
4348 let left_indices: Vec<usize> = join_cols
4349 .iter()
4350 .filter_map(|(lc, _)| batch.schema().index_of(lc).ok())
4351 .collect();
4352 if left_indices.len() != join_cols.len() {
4353 result.push(batch);
4354 continue;
4355 }
4356 let all_u64 = left_indices.iter().all(|&ci| {
4357 batch
4358 .column(ci)
4359 .as_any()
4360 .downcast_ref::<UInt64Array>()
4361 .is_some()
4362 });
4363 if !all_u64 {
4364 result.push(batch);
4365 continue;
4366 }
4367
4368 let complements: Vec<f64> = (0..batch.num_rows())
4369 .map(|row| {
4370 let mut key = Vec::with_capacity(left_indices.len());
4371 for &ci in &left_indices {
4372 let vids = batch
4373 .column(ci)
4374 .as_any()
4375 .downcast_ref::<UInt64Array>()
4376 .unwrap();
4377 if vids.is_null(row) {
4378 return 1.0;
4379 }
4380 key.push(vids.value(row));
4381 }
4382 let p = prob_map.get(&key).copied().unwrap_or(0.0);
4383 1.0 - p
4384 })
4385 .collect();
4386
4387 let complement_arr = Float64Array::from(complements);
4388 let mut columns: Vec<arrow_array::ArrayRef> = batch.columns().to_vec();
4389 columns.push(Arc::new(complement_arr));
4390
4391 let mut fields: Vec<Arc<arrow_schema::Field>> =
4392 batch.schema().fields().iter().cloned().collect();
4393 fields.push(Arc::new(arrow_schema::Field::new(
4394 complement_col_name,
4395 arrow_schema::DataType::Float64,
4396 true,
4397 )));
4398
4399 let new_schema = Arc::new(arrow_schema::Schema::new(fields));
4400 let new_batch = RecordBatch::try_new(new_schema, columns).map_err(arrow_err)?;
4401 result.push(new_batch);
4402 }
4403 Ok(result)
4404}
4405
4406pub fn apply_anti_join_composite(
4412 batches: Vec<RecordBatch>,
4413 neg_facts: &[RecordBatch],
4414 join_cols: &[(String, String)],
4415) -> datafusion::error::Result<Vec<RecordBatch>> {
4416 use arrow::compute::filter_record_batch;
4417 use arrow_array::{Array as _, BooleanArray, UInt64Array};
4418
4419 let mut banned: HashSet<Vec<u64>> = HashSet::new();
4421 for batch in neg_facts {
4422 let right_indices: Vec<usize> = join_cols
4423 .iter()
4424 .filter_map(|(_, rc)| batch.schema().index_of(rc).ok())
4425 .collect();
4426 if right_indices.len() != join_cols.len() {
4427 continue;
4428 }
4429 for row in 0..batch.num_rows() {
4430 let mut key = Vec::with_capacity(right_indices.len());
4431 let mut valid = true;
4432 for &ci in &right_indices {
4433 let col = batch.column(ci);
4434 if let Some(vids) = col.as_any().downcast_ref::<UInt64Array>() {
4435 if vids.is_null(row) {
4436 valid = false;
4437 break;
4438 }
4439 key.push(vids.value(row));
4440 } else {
4441 valid = false;
4442 break;
4443 }
4444 }
4445 if valid {
4446 banned.insert(key);
4447 }
4448 }
4449 }
4450
4451 if banned.is_empty() {
4452 return Ok(batches);
4453 }
4454
4455 let mut result = Vec::new();
4457 for batch in batches {
4458 let left_indices: Vec<usize> = join_cols
4459 .iter()
4460 .filter_map(|(lc, _)| batch.schema().index_of(lc).ok())
4461 .collect();
4462 if left_indices.len() != join_cols.len() {
4463 result.push(batch);
4464 continue;
4465 }
4466 let all_u64 = left_indices.iter().all(|&ci| {
4467 batch
4468 .column(ci)
4469 .as_any()
4470 .downcast_ref::<UInt64Array>()
4471 .is_some()
4472 });
4473 if !all_u64 {
4474 result.push(batch);
4475 continue;
4476 }
4477
4478 let keep: Vec<bool> = (0..batch.num_rows())
4479 .map(|row| {
4480 let mut key = Vec::with_capacity(left_indices.len());
4481 for &ci in &left_indices {
4482 let vids = batch
4483 .column(ci)
4484 .as_any()
4485 .downcast_ref::<UInt64Array>()
4486 .unwrap();
4487 if vids.is_null(row) {
4488 return true; }
4490 key.push(vids.value(row));
4491 }
4492 !banned.contains(&key)
4493 })
4494 .collect();
4495 let keep_arr = BooleanArray::from(keep);
4496 let filtered = filter_record_batch(&batch, &keep_arr).map_err(arrow_err)?;
4497 if filtered.num_rows() > 0 {
4498 result.push(filtered);
4499 }
4500 }
4501 Ok(result)
4502}
4503
4504pub fn multiply_prob_factors(
4515 batches: Vec<RecordBatch>,
4516 prob_col: Option<&str>,
4517 complement_cols: &[String],
4518) -> datafusion::error::Result<Vec<RecordBatch>> {
4519 use arrow_array::{Array as _, Float64Array};
4520
4521 let mut result = Vec::with_capacity(batches.len());
4522
4523 for batch in batches {
4524 if batch.num_rows() == 0 {
4525 let keep: Vec<usize> = batch
4527 .schema()
4528 .fields()
4529 .iter()
4530 .enumerate()
4531 .filter(|(_, f)| !complement_cols.contains(f.name()))
4532 .map(|(i, _)| i)
4533 .collect();
4534 let fields: Vec<_> = keep
4535 .iter()
4536 .map(|&i| batch.schema().field(i).clone())
4537 .collect();
4538 let cols: Vec<_> = keep.iter().map(|&i| batch.column(i).clone()).collect();
4539 let schema = std::sync::Arc::new(arrow_schema::Schema::new(fields));
4540 result.push(
4541 RecordBatch::try_new(schema, cols).map_err(|e| {
4542 datafusion::error::DataFusionError::ArrowError(Box::new(e), None)
4543 })?,
4544 );
4545 continue;
4546 }
4547
4548 let num_rows = batch.num_rows();
4549
4550 let mut combined = vec![1.0f64; num_rows];
4552 for col_name in complement_cols {
4553 if let Ok(idx) = batch.schema().index_of(col_name) {
4554 let arr = batch
4555 .column(idx)
4556 .as_any()
4557 .downcast_ref::<Float64Array>()
4558 .ok_or_else(|| {
4559 datafusion::error::DataFusionError::Internal(format!(
4560 "Expected Float64 for complement column {col_name}"
4561 ))
4562 })?;
4563 for (i, val) in combined.iter_mut().enumerate().take(num_rows) {
4564 if !arr.is_null(i) {
4565 *val *= arr.value(i);
4566 }
4567 }
4568 }
4569 }
4570
4571 let final_prob: Vec<f64> = if let Some(prob_name) = prob_col {
4573 if let Ok(idx) = batch.schema().index_of(prob_name) {
4574 let arr = batch
4575 .column(idx)
4576 .as_any()
4577 .downcast_ref::<Float64Array>()
4578 .ok_or_else(|| {
4579 datafusion::error::DataFusionError::Internal(format!(
4580 "Expected Float64 for PROB column {prob_name}"
4581 ))
4582 })?;
4583 (0..num_rows)
4584 .map(|i| {
4585 if arr.is_null(i) {
4586 combined[i]
4587 } else {
4588 arr.value(i) * combined[i]
4589 }
4590 })
4591 .collect()
4592 } else {
4593 combined
4594 }
4595 } else {
4596 combined
4597 };
4598
4599 let new_prob_array: arrow_array::ArrayRef =
4600 std::sync::Arc::new(Float64Array::from(final_prob));
4601
4602 let mut fields = Vec::new();
4604 let mut columns = Vec::new();
4605
4606 for (idx, field) in batch.schema().fields().iter().enumerate() {
4607 if complement_cols.contains(field.name()) {
4608 continue;
4609 }
4610 if prob_col.is_some_and(|p| field.name() == p) {
4611 fields.push(field.clone());
4612 columns.push(new_prob_array.clone());
4613 } else {
4614 fields.push(field.clone());
4615 columns.push(batch.column(idx).clone());
4616 }
4617 }
4618
4619 let schema = std::sync::Arc::new(arrow_schema::Schema::new(fields));
4620 result.push(RecordBatch::try_new(schema, columns).map_err(arrow_err)?);
4621 }
4622
4623 Ok(result)
4624}
4625
4626fn update_derived_scan_handles(
4631 registry: &DerivedScanRegistry,
4632 states: &[FixpointState],
4633 current_rule_idx: usize,
4634 rules: &[FixpointRulePlan],
4635) {
4636 let current_rule_name = &rules[current_rule_idx].name;
4637
4638 for entry in ®istry.entries {
4639 let source_state_idx = rules.iter().position(|r| r.name == entry.rule_name);
4641 let Some(source_idx) = source_state_idx else {
4642 continue;
4643 };
4644
4645 let is_self = entry.rule_name == *current_rule_name;
4646 let data = if is_self && !rules[current_rule_idx].non_linear {
4647 states[source_idx].all_delta().to_vec()
4649 } else {
4650 states[source_idx].all_facts().to_vec()
4653 };
4654
4655 let data = if data.is_empty() || data.iter().all(|b| b.num_rows() == 0) {
4657 vec![RecordBatch::new_empty(Arc::clone(&entry.schema))]
4658 } else {
4659 data
4660 };
4661
4662 let mut guard = entry.data.write();
4663 *guard = data;
4664 }
4665}
4666
4667pub struct DerivedScanExec {
4677 data: Arc<RwLock<Vec<RecordBatch>>>,
4678 schema: SchemaRef,
4679 properties: Arc<PlanProperties>,
4680}
4681
4682impl DerivedScanExec {
4683 pub fn new(data: Arc<RwLock<Vec<RecordBatch>>>, schema: SchemaRef) -> Self {
4684 let properties = compute_plan_properties(Arc::clone(&schema));
4685 Self {
4686 data,
4687 schema,
4688 properties,
4689 }
4690 }
4691}
4692
4693impl fmt::Debug for DerivedScanExec {
4694 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
4695 f.debug_struct("DerivedScanExec")
4696 .field("schema", &self.schema)
4697 .finish()
4698 }
4699}
4700
4701impl DisplayAs for DerivedScanExec {
4702 fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter<'_>) -> fmt::Result {
4703 write!(f, "DerivedScanExec")
4704 }
4705}
4706
4707impl ExecutionPlan for DerivedScanExec {
4708 fn name(&self) -> &str {
4709 "DerivedScanExec"
4710 }
4711 fn as_any(&self) -> &dyn Any {
4712 self
4713 }
4714 fn schema(&self) -> SchemaRef {
4715 Arc::clone(&self.schema)
4716 }
4717 fn properties(&self) -> &Arc<PlanProperties> {
4718 &self.properties
4719 }
4720 fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
4721 vec![]
4722 }
4723 fn with_new_children(
4724 self: Arc<Self>,
4725 _children: Vec<Arc<dyn ExecutionPlan>>,
4726 ) -> DFResult<Arc<dyn ExecutionPlan>> {
4727 Ok(self)
4728 }
4729 fn execute(
4730 &self,
4731 _partition: usize,
4732 _context: Arc<TaskContext>,
4733 ) -> DFResult<SendableRecordBatchStream> {
4734 let batches = {
4735 let guard = self.data.read();
4736 if guard.is_empty() {
4737 vec![RecordBatch::new_empty(Arc::clone(&self.schema))]
4738 } else {
4739 guard
4745 .iter()
4746 .map(|b| {
4747 RecordBatch::try_new(Arc::clone(&self.schema), b.columns().to_vec())
4748 .map_err(|e| {
4749 datafusion::error::DataFusionError::ArrowError(Box::new(e), None)
4750 })
4751 })
4752 .collect::<DFResult<Vec<_>>>()?
4753 }
4754 };
4755 Ok(Box::pin(MemoryStream::try_new(
4756 batches,
4757 Arc::clone(&self.schema),
4758 None,
4759 )?))
4760 }
4761}
4762
4763struct InMemoryExec {
4772 batches: Vec<RecordBatch>,
4773 schema: SchemaRef,
4774 properties: Arc<PlanProperties>,
4775}
4776
4777impl InMemoryExec {
4778 fn new(batches: Vec<RecordBatch>, schema: SchemaRef) -> Self {
4779 let properties = compute_plan_properties(Arc::clone(&schema));
4780 Self {
4781 batches,
4782 schema,
4783 properties,
4784 }
4785 }
4786}
4787
4788impl fmt::Debug for InMemoryExec {
4789 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
4790 f.debug_struct("InMemoryExec")
4791 .field("num_batches", &self.batches.len())
4792 .field("schema", &self.schema)
4793 .finish()
4794 }
4795}
4796
4797impl DisplayAs for InMemoryExec {
4798 fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter<'_>) -> fmt::Result {
4799 write!(f, "InMemoryExec: batches={}", self.batches.len())
4800 }
4801}
4802
4803impl ExecutionPlan for InMemoryExec {
4804 fn name(&self) -> &str {
4805 "InMemoryExec"
4806 }
4807 fn as_any(&self) -> &dyn Any {
4808 self
4809 }
4810 fn schema(&self) -> SchemaRef {
4811 Arc::clone(&self.schema)
4812 }
4813 fn properties(&self) -> &Arc<PlanProperties> {
4814 &self.properties
4815 }
4816 fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
4817 vec![]
4818 }
4819 fn with_new_children(
4820 self: Arc<Self>,
4821 _children: Vec<Arc<dyn ExecutionPlan>>,
4822 ) -> DFResult<Arc<dyn ExecutionPlan>> {
4823 Ok(self)
4824 }
4825 fn execute(
4826 &self,
4827 _partition: usize,
4828 _context: Arc<TaskContext>,
4829 ) -> DFResult<SendableRecordBatchStream> {
4830 Ok(Box::pin(MemoryStream::try_new(
4831 self.batches.clone(),
4832 Arc::clone(&self.schema),
4833 None,
4834 )?))
4835 }
4836}
4837
4838fn apply_having_filter(
4848 batches: Vec<RecordBatch>,
4849 having_exprs: &[Expr],
4850 schema: &SchemaRef,
4851 task_ctx: &Arc<TaskContext>,
4852) -> DFResult<Vec<RecordBatch>> {
4853 use arrow::compute::{and, filter_record_batch};
4854 use arrow_array::BooleanArray;
4855 use datafusion::common::DFSchema;
4856 use datafusion::logical_expr::LogicalPlanBuilder;
4857 use datafusion::logical_expr::execution_props::ExecutionProps;
4858 use datafusion::optimizer::AnalyzerRule;
4859 use datafusion::optimizer::analyzer::type_coercion::TypeCoercion;
4860 use datafusion::physical_expr::create_physical_expr;
4861
4862 if batches.is_empty() {
4863 return Ok(batches);
4864 }
4865
4866 let df_schema = DFSchema::try_from(schema.as_ref().clone()).map_err(|e| {
4868 datafusion::common::DataFusionError::Internal(format!("HAVING schema conversion: {e}"))
4869 })?;
4870
4871 let config = (**task_ctx.session_config().options()).clone();
4876 let props = ExecutionProps::new();
4877
4878 let physical_exprs: Vec<Arc<dyn datafusion::physical_expr::PhysicalExpr>> = having_exprs
4884 .iter()
4885 .map(|expr| {
4886 let df_expr = crate::query::df_expr::cypher_expr_to_df(expr, None).map_err(|e| {
4887 datafusion::common::DataFusionError::Internal(format!(
4888 "HAVING expression conversion: {e}"
4889 ))
4890 })?;
4891
4892 let empty = datafusion::logical_expr::LogicalPlan::EmptyRelation(
4896 datafusion::logical_expr::EmptyRelation {
4897 produce_one_row: false,
4898 schema: Arc::new(df_schema.clone()),
4899 },
4900 );
4901 let filter_plan = LogicalPlanBuilder::from(empty)
4902 .filter(df_expr.clone())?
4903 .build()?;
4904 let coerced_expr = match TypeCoercion::new().analyze(filter_plan, &config) {
4905 Ok(datafusion::logical_expr::LogicalPlan::Filter(f)) => f.predicate,
4906 _ => df_expr,
4907 };
4908
4909 create_physical_expr(&coerced_expr, &df_schema, &props)
4910 })
4911 .collect::<DFResult<Vec<_>>>()?;
4912
4913 let mut result = Vec::new();
4914 for batch in batches {
4915 let mut mask: Option<BooleanArray> = None;
4917 for phys_expr in &physical_exprs {
4918 let value = phys_expr.evaluate(&batch)?;
4919 let arr = value.into_array(batch.num_rows())?;
4920 let bool_arr = arr.as_any().downcast_ref::<BooleanArray>().ok_or_else(|| {
4921 datafusion::common::DataFusionError::Internal(
4922 "HAVING condition must evaluate to boolean".into(),
4923 )
4924 })?;
4925 mask = Some(match mask {
4926 None => bool_arr.clone(),
4927 Some(prev) => and(&prev, bool_arr).map_err(arrow_err)?,
4928 });
4929 }
4930 if let Some(ref m) = mask {
4931 let filtered = filter_record_batch(&batch, m).map_err(arrow_err)?;
4932 if filtered.num_rows() > 0 {
4933 result.push(filtered);
4934 }
4935 } else {
4936 result.push(batch);
4937 }
4938 }
4939 Ok(result)
4940}
4941
4942#[allow(
4944 clippy::too_many_arguments,
4945 reason = "context bundle would be over-engineering for one call site"
4946)]
4947pub(crate) async fn apply_post_fixpoint_chain(
4948 facts: Vec<RecordBatch>,
4949 rule: &FixpointRulePlan,
4950 task_ctx: &Arc<TaskContext>,
4951 strict_probability_domain: bool,
4952 probability_epsilon: f64,
4953 semiring_kind: SemiringKind,
4954 provenance_tracker: Option<Arc<ProvenanceStore>>,
4955 top_k_proofs_k: usize,
4956 registry: Option<Arc<DerivedScanRegistry>>,
4957) -> DFResult<Vec<RecordBatch>> {
4958 if !rule.has_fold && !rule.has_best_by && !rule.has_priority && rule.having.is_empty() {
4959 return Ok(facts);
4960 }
4961
4962 let schema = facts
4967 .iter()
4968 .find(|b| b.num_rows() > 0)
4969 .map(|b| b.schema())
4970 .unwrap_or_else(|| Arc::clone(&rule.yield_schema));
4971
4972 let topk_k: Option<usize> = match semiring_kind {
4986 SemiringKind::TopKProofs { k } if k > 0 => Some(k as usize),
4987 _ => None,
4988 };
4989 let body_support_map: Option<Arc<HashMap<Vec<u8>, Vec<ProofTerm>>>> = if topk_k.is_some()
4990 && !rule.has_priority
4991 && let Some(registry) = registry.as_ref()
4992 {
4993 let mut map: HashMap<Vec<u8>, Vec<ProofTerm>> = HashMap::new();
4994 for batch in &facts {
4995 let all_indices: Vec<usize> = (0..batch.num_columns()).collect();
4996 for row_idx in 0..batch.num_rows() {
4997 let support = collect_is_ref_inputs_for_body_row(rule, batch, row_idx, registry);
4998 if support.is_empty() {
4999 continue;
5000 }
5001 let hash = fact_hash_key(batch, &all_indices, row_idx);
5002 map.insert(hash, support);
5003 }
5004 }
5005 if map.is_empty() {
5006 None
5007 } else {
5008 Some(Arc::new(map))
5009 }
5010 } else {
5011 None
5012 };
5013
5014 let input: Arc<dyn ExecutionPlan> = Arc::new(InMemoryExec::new(facts, schema.clone()));
5015
5016 let key_column_indices: Vec<usize> = rule
5021 .key_column_indices
5022 .iter()
5023 .filter_map(|&i| {
5024 let name = rule.yield_schema.field(i).name();
5025 schema.index_of(name).ok()
5026 })
5027 .collect();
5028
5029 let current: Arc<dyn ExecutionPlan> = if rule.has_priority {
5033 let priority_schema = input.schema();
5034 let priority_idx = priority_schema.index_of("__priority").map_err(|_| {
5035 datafusion::common::DataFusionError::Internal(
5036 "PRIORITY rule missing __priority column".to_string(),
5037 )
5038 })?;
5039 Arc::new(PriorityExec::new(
5040 input,
5041 key_column_indices.clone(),
5042 priority_idx,
5043 ))
5044 } else {
5045 input
5046 };
5047
5048 let current: Arc<dyn ExecutionPlan> = if rule.has_fold && !rule.fold_bindings.is_empty() {
5050 Arc::new(FoldExec::new_with_topk(
5051 current,
5052 key_column_indices.clone(),
5053 rule.fold_bindings.clone(),
5054 strict_probability_domain,
5055 probability_epsilon,
5056 semiring_kind,
5057 provenance_tracker.clone(),
5058 topk_k.unwrap_or(top_k_proofs_k),
5059 body_support_map.clone(),
5060 ))
5061 } else {
5062 current
5063 };
5064
5065 let current: Arc<dyn ExecutionPlan> = if !rule.having.is_empty() {
5067 let batches = collect_all_partitions(¤t, Arc::clone(task_ctx)).await?;
5068 let filtered = apply_having_filter(batches, &rule.having, ¤t.schema(), task_ctx)?;
5069 if filtered.is_empty() {
5070 return Ok(filtered);
5071 }
5072 Arc::new(InMemoryExec::new(filtered, Arc::clone(¤t.schema())))
5073 } else {
5074 current
5075 };
5076
5077 let current: Arc<dyn ExecutionPlan> = if rule.has_best_by && !rule.best_by_criteria.is_empty() {
5079 Arc::new(BestByExec::new(
5080 current,
5081 key_column_indices.clone(),
5082 rule.best_by_criteria.clone(),
5083 rule.deterministic,
5084 ))
5085 } else {
5086 current
5087 };
5088
5089 collect_all_partitions(¤t, Arc::clone(task_ctx)).await
5090}
5091
/// Execution plan that evaluates a group of Locy rules (`rules`) to a
/// fixpoint.
///
/// NOTE(review): this summary and several field notes below are inferred
/// from type and field names. The evaluation loop itself is outside this
/// view, so confirm against it.
pub struct FixpointExec {
    /// Plans of the rules evaluated by this operator.
    rules: Vec<FixpointRulePlan>,
    /// Presumably the iteration cap for the fixpoint loop; confirm.
    max_iterations: usize,
    /// Presumably the wall-clock budget for evaluation; confirm.
    timeout: Duration,
    /// Graph execution context shared with the rule subplans.
    graph_ctx: Arc<GraphExecutionContext>,
    /// DataFusion session used to plan and run rule bodies (assumed).
    session_ctx: Arc<RwLock<datafusion::prelude::SessionContext>>,
    /// Storage handle for graph access.
    storage: Arc<StorageManager>,
    /// Graph schema.
    schema_info: Arc<UniSchema>,
    /// Query parameters, by name.
    params: HashMap<String, Value>,
    /// Registry of derived-scan slots that rule bodies read from.
    derived_scan_registry: Arc<DerivedScanRegistry>,
    /// Schema of the rows this plan produces.
    output_schema: SchemaRef,
    /// Cached DataFusion plan properties.
    properties: Arc<PlanProperties>,
    /// DataFusion execution metrics.
    metrics: ExecutionPlanMetricsSet,
    /// Presumably a size cap on derived facts, in bytes; confirm.
    max_derived_bytes: usize,
    /// Optional provenance store for derivation tracking.
    derivation_tracker: Option<Arc<ProvenanceStore>>,
    /// Shared per-rule iteration counts, keyed by rule name (assumed).
    iteration_counts: Arc<StdRwLock<HashMap<String, usize>>>,
    /// Probability-domain strictness switch; exact semantics not visible here.
    strict_probability_domain: bool,
    /// Probability tolerance; exact use not visible here.
    probability_epsilon: f64,
    /// Presumably enables exact (BDD-based) probability computation; confirm.
    exact_probability: bool,
    /// Presumably the BDD variable budget before approximating; confirm.
    max_bdd_variables: usize,
    /// Shared sink for runtime warnings.
    warnings_slot: Arc<StdRwLock<Vec<RuntimeWarning>>>,
    /// Shared map of approximated results, keyed by name (assumed: rule name).
    approximate_slot: Arc<StdRwLock<HashMap<String, Vec<String>>>>,
    /// Default k for top-k proof tracking (assumed).
    top_k_proofs: usize,
    /// Shared flag, presumably recording a timeout outcome; confirm encoding.
    timeout_flag: Arc<std::sync::atomic::AtomicU8>,
    /// Semiring used to combine probabilities/proofs.
    semiring_kind: SemiringKind,
    /// Classifiers available to model-invocation features.
    classifier_registry: Arc<ClassifierRegistry>,
    /// Optional cache for model invocation results.
    classifier_cache: Option<Arc<ModelInvocationCache>>,
    /// Neural provenance store; not read by this type directly.
    #[allow(
        dead_code,
        reason = "boundary plumbing; read by EXPLAIN via LocyModelInvokeExec"
    )]
    classifier_provenance_store: Option<Arc<uni_locy::NeuralProvenanceStore>>,
}
5154
5155impl fmt::Debug for FixpointExec {
5156 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
5157 f.debug_struct("FixpointExec")
5158 .field("rules_count", &self.rules.len())
5159 .field("max_iterations", &self.max_iterations)
5160 .field("timeout", &self.timeout)
5161 .field("output_schema", &self.output_schema)
5162 .field("max_derived_bytes", &self.max_derived_bytes)
5163 .finish_non_exhaustive()
5164 }
5165}
5166
5167impl FixpointExec {
5168 #[expect(
5170 clippy::too_many_arguments,
5171 reason = "FixpointExec configuration needs all context"
5172 )]
5173 #[deprecated(
5174 note = "use `new_with_semiring_classifiers_and_cache` (or the lighter \
5175 `new_with_semiring_and_classifiers` / `new_with_semiring`) — \
5176 this legacy ctor defaults the semiring to AddMultProb and \
5177 ships no classifier registry, which the Phase B+ runtime needs \
5178 explicitly. To be removed after C0 Stage 2."
5179 )]
5180 pub fn new(
5181 rules: Vec<FixpointRulePlan>,
5182 max_iterations: usize,
5183 timeout: Duration,
5184 graph_ctx: Arc<GraphExecutionContext>,
5185 session_ctx: Arc<RwLock<datafusion::prelude::SessionContext>>,
5186 storage: Arc<StorageManager>,
5187 schema_info: Arc<UniSchema>,
5188 params: HashMap<String, Value>,
5189 derived_scan_registry: Arc<DerivedScanRegistry>,
5190 output_schema: SchemaRef,
5191 max_derived_bytes: usize,
5192 derivation_tracker: Option<Arc<ProvenanceStore>>,
5193 iteration_counts: Arc<StdRwLock<HashMap<String, usize>>>,
5194 strict_probability_domain: bool,
5195 probability_epsilon: f64,
5196 exact_probability: bool,
5197 max_bdd_variables: usize,
5198 warnings_slot: Arc<StdRwLock<Vec<RuntimeWarning>>>,
5199 approximate_slot: Arc<StdRwLock<HashMap<String, Vec<String>>>>,
5200 top_k_proofs: usize,
5201 timeout_flag: Arc<std::sync::atomic::AtomicU8>,
5202 ) -> Self {
5203 Self::new_with_semiring_and_classifiers(
5204 rules,
5205 max_iterations,
5206 timeout,
5207 graph_ctx,
5208 session_ctx,
5209 storage,
5210 schema_info,
5211 params,
5212 derived_scan_registry,
5213 output_schema,
5214 max_derived_bytes,
5215 derivation_tracker,
5216 iteration_counts,
5217 strict_probability_domain,
5218 probability_epsilon,
5219 exact_probability,
5220 max_bdd_variables,
5221 warnings_slot,
5222 approximate_slot,
5223 top_k_proofs,
5224 timeout_flag,
5225 SemiringKind::AddMultProb,
5226 Arc::new(ClassifierRegistry::new()),
5227 )
5228 }
5229
5230 #[expect(
5234 clippy::too_many_arguments,
5235 reason = "FixpointExec configuration needs all context"
5236 )]
5237 pub fn new_with_semiring(
5238 rules: Vec<FixpointRulePlan>,
5239 max_iterations: usize,
5240 timeout: Duration,
5241 graph_ctx: Arc<GraphExecutionContext>,
5242 session_ctx: Arc<RwLock<datafusion::prelude::SessionContext>>,
5243 storage: Arc<StorageManager>,
5244 schema_info: Arc<UniSchema>,
5245 params: HashMap<String, Value>,
5246 derived_scan_registry: Arc<DerivedScanRegistry>,
5247 output_schema: SchemaRef,
5248 max_derived_bytes: usize,
5249 derivation_tracker: Option<Arc<ProvenanceStore>>,
5250 iteration_counts: Arc<StdRwLock<HashMap<String, usize>>>,
5251 strict_probability_domain: bool,
5252 probability_epsilon: f64,
5253 exact_probability: bool,
5254 max_bdd_variables: usize,
5255 warnings_slot: Arc<StdRwLock<Vec<RuntimeWarning>>>,
5256 approximate_slot: Arc<StdRwLock<HashMap<String, Vec<String>>>>,
5257 top_k_proofs: usize,
5258 timeout_flag: Arc<std::sync::atomic::AtomicU8>,
5259 semiring_kind: SemiringKind,
5260 ) -> Self {
5261 Self::new_with_semiring_and_classifiers(
5262 rules,
5263 max_iterations,
5264 timeout,
5265 graph_ctx,
5266 session_ctx,
5267 storage,
5268 schema_info,
5269 params,
5270 derived_scan_registry,
5271 output_schema,
5272 max_derived_bytes,
5273 derivation_tracker,
5274 iteration_counts,
5275 strict_probability_domain,
5276 probability_epsilon,
5277 exact_probability,
5278 max_bdd_variables,
5279 warnings_slot,
5280 approximate_slot,
5281 top_k_proofs,
5282 timeout_flag,
5283 semiring_kind,
5284 Arc::new(ClassifierRegistry::new()),
5285 )
5286 }
5287
5288 #[expect(
5292 clippy::too_many_arguments,
5293 reason = "FixpointExec configuration needs all context"
5294 )]
5295 pub fn new_with_semiring_and_classifiers(
5296 rules: Vec<FixpointRulePlan>,
5297 max_iterations: usize,
5298 timeout: Duration,
5299 graph_ctx: Arc<GraphExecutionContext>,
5300 session_ctx: Arc<RwLock<datafusion::prelude::SessionContext>>,
5301 storage: Arc<StorageManager>,
5302 schema_info: Arc<UniSchema>,
5303 params: HashMap<String, Value>,
5304 derived_scan_registry: Arc<DerivedScanRegistry>,
5305 output_schema: SchemaRef,
5306 max_derived_bytes: usize,
5307 derivation_tracker: Option<Arc<ProvenanceStore>>,
5308 iteration_counts: Arc<StdRwLock<HashMap<String, usize>>>,
5309 strict_probability_domain: bool,
5310 probability_epsilon: f64,
5311 exact_probability: bool,
5312 max_bdd_variables: usize,
5313 warnings_slot: Arc<StdRwLock<Vec<RuntimeWarning>>>,
5314 approximate_slot: Arc<StdRwLock<HashMap<String, Vec<String>>>>,
5315 top_k_proofs: usize,
5316 timeout_flag: Arc<std::sync::atomic::AtomicU8>,
5317 semiring_kind: SemiringKind,
5318 classifier_registry: Arc<ClassifierRegistry>,
5319 ) -> Self {
5320 Self::new_with_semiring_classifiers_and_cache(
5321 rules,
5322 max_iterations,
5323 timeout,
5324 graph_ctx,
5325 session_ctx,
5326 storage,
5327 schema_info,
5328 params,
5329 derived_scan_registry,
5330 output_schema,
5331 max_derived_bytes,
5332 derivation_tracker,
5333 iteration_counts,
5334 strict_probability_domain,
5335 probability_epsilon,
5336 exact_probability,
5337 max_bdd_variables,
5338 warnings_slot,
5339 approximate_slot,
5340 top_k_proofs,
5341 timeout_flag,
5342 semiring_kind,
5343 classifier_registry,
5344 None,
5345 None,
5346 )
5347 }
5348
5349 #[expect(
5353 clippy::too_many_arguments,
5354 reason = "FixpointExec configuration needs all context"
5355 )]
5356 pub fn new_with_semiring_classifiers_and_cache(
5357 rules: Vec<FixpointRulePlan>,
5358 max_iterations: usize,
5359 timeout: Duration,
5360 graph_ctx: Arc<GraphExecutionContext>,
5361 session_ctx: Arc<RwLock<datafusion::prelude::SessionContext>>,
5362 storage: Arc<StorageManager>,
5363 schema_info: Arc<UniSchema>,
5364 params: HashMap<String, Value>,
5365 derived_scan_registry: Arc<DerivedScanRegistry>,
5366 output_schema: SchemaRef,
5367 max_derived_bytes: usize,
5368 derivation_tracker: Option<Arc<ProvenanceStore>>,
5369 iteration_counts: Arc<StdRwLock<HashMap<String, usize>>>,
5370 strict_probability_domain: bool,
5371 probability_epsilon: f64,
5372 exact_probability: bool,
5373 max_bdd_variables: usize,
5374 warnings_slot: Arc<StdRwLock<Vec<RuntimeWarning>>>,
5375 approximate_slot: Arc<StdRwLock<HashMap<String, Vec<String>>>>,
5376 top_k_proofs: usize,
5377 timeout_flag: Arc<std::sync::atomic::AtomicU8>,
5378 semiring_kind: SemiringKind,
5379 classifier_registry: Arc<ClassifierRegistry>,
5380 classifier_cache: Option<Arc<ModelInvocationCache>>,
5381 classifier_provenance_store: Option<Arc<uni_locy::NeuralProvenanceStore>>,
5382 ) -> Self {
5383 let properties = compute_plan_properties(Arc::clone(&output_schema));
5384 Self {
5385 rules,
5386 max_iterations,
5387 timeout,
5388 graph_ctx,
5389 session_ctx,
5390 storage,
5391 schema_info,
5392 params,
5393 derived_scan_registry,
5394 output_schema,
5395 properties,
5396 metrics: ExecutionPlanMetricsSet::new(),
5397 max_derived_bytes,
5398 derivation_tracker,
5399 iteration_counts,
5400 strict_probability_domain,
5401 probability_epsilon,
5402 exact_probability,
5403 max_bdd_variables,
5404 warnings_slot,
5405 approximate_slot,
5406 top_k_proofs,
5407 timeout_flag,
5408 semiring_kind,
5409 classifier_registry,
5410 classifier_cache,
5411 classifier_provenance_store,
5412 }
5413 }
5414
5415 pub fn iteration_counts(&self) -> Arc<StdRwLock<HashMap<String, usize>>> {
5417 Arc::clone(&self.iteration_counts)
5418 }
5419}
5420
5421impl DisplayAs for FixpointExec {
5422 fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter<'_>) -> fmt::Result {
5423 write!(
5424 f,
5425 "FixpointExec: rules=[{}], max_iter={}, timeout={:?}",
5426 self.rules
5427 .iter()
5428 .map(|r| r.name.as_str())
5429 .collect::<Vec<_>>()
5430 .join(", "),
5431 self.max_iterations,
5432 self.timeout,
5433 )
5434 }
5435}
5436
5437impl ExecutionPlan for FixpointExec {
5438 fn name(&self) -> &str {
5439 "FixpointExec"
5440 }
5441
5442 fn as_any(&self) -> &dyn Any {
5443 self
5444 }
5445
5446 fn schema(&self) -> SchemaRef {
5447 Arc::clone(&self.output_schema)
5448 }
5449
5450 fn properties(&self) -> &Arc<PlanProperties> {
5451 &self.properties
5452 }
5453
5454 fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
5455 vec![]
5457 }
5458
5459 fn with_new_children(
5460 self: Arc<Self>,
5461 children: Vec<Arc<dyn ExecutionPlan>>,
5462 ) -> DFResult<Arc<dyn ExecutionPlan>> {
5463 if !children.is_empty() {
5464 return Err(datafusion::error::DataFusionError::Plan(
5465 "FixpointExec has no children".to_string(),
5466 ));
5467 }
5468 Ok(self)
5469 }
5470
5471 fn execute(
5472 &self,
5473 partition: usize,
5474 _context: Arc<TaskContext>,
5475 ) -> DFResult<SendableRecordBatchStream> {
5476 let metrics = BaselineMetrics::new(&self.metrics, partition);
5477
5478 let rules = self
5480 .rules
5481 .iter()
5482 .map(|r| {
5483 FixpointRulePlan {
5487 name: r.name.clone(),
5488 clauses: r
5489 .clauses
5490 .iter()
5491 .map(|c| FixpointClausePlan {
5492 body_logical: c.body_logical.clone(),
5493 is_ref_bindings: c.is_ref_bindings.clone(),
5494 priority: c.priority,
5495 along_bindings: c.along_bindings.clone(),
5496 model_invocations: c.model_invocations.clone(),
5497 })
5498 .collect(),
5499 yield_schema: Arc::clone(&r.yield_schema),
5500 key_column_indices: r.key_column_indices.clone(),
5501 priority: r.priority,
5502 has_fold: r.has_fold,
5503 fold_bindings: r.fold_bindings.clone(),
5504 having: r.having.clone(),
5505 has_best_by: r.has_best_by,
5506 best_by_criteria: r.best_by_criteria.clone(),
5507 has_priority: r.has_priority,
5508 deterministic: r.deterministic,
5509 prob_column_name: r.prob_column_name.clone(),
5510 non_linear: r.non_linear,
5511 }
5512 })
5513 .collect();
5514
5515 let max_iterations = self.max_iterations;
5516 let timeout = self.timeout;
5517 let graph_ctx = Arc::clone(&self.graph_ctx);
5518 let session_ctx = Arc::clone(&self.session_ctx);
5519 let storage = Arc::clone(&self.storage);
5520 let schema_info = Arc::clone(&self.schema_info);
5521 let params = self.params.clone();
5522 let registry = Arc::clone(&self.derived_scan_registry);
5523 let output_schema = Arc::clone(&self.output_schema);
5524 let max_derived_bytes = self.max_derived_bytes;
5525 let derivation_tracker = self.derivation_tracker.clone();
5526 let iteration_counts = Arc::clone(&self.iteration_counts);
5527 let strict_probability_domain = self.strict_probability_domain;
5528 let probability_epsilon = self.probability_epsilon;
5529 let exact_probability = self.exact_probability;
5530 let max_bdd_variables = self.max_bdd_variables;
5531 let warnings_slot = Arc::clone(&self.warnings_slot);
5532 let approximate_slot = Arc::clone(&self.approximate_slot);
5533 let top_k_proofs = self.top_k_proofs;
5534 let timeout_flag = Arc::clone(&self.timeout_flag);
5535 let semiring_kind = self.semiring_kind;
5536 let classifier_registry = Arc::clone(&self.classifier_registry);
5537 let classifier_cache = self.classifier_cache.as_ref().map(Arc::clone);
5538 let classifier_provenance_store = self.classifier_provenance_store.as_ref().map(Arc::clone);
5539
5540 let fut = async move {
5541 run_fixpoint_loop(
5542 rules,
5543 max_iterations,
5544 timeout,
5545 graph_ctx,
5546 session_ctx,
5547 storage,
5548 schema_info,
5549 params,
5550 registry,
5551 output_schema,
5552 max_derived_bytes,
5553 derivation_tracker,
5554 iteration_counts,
5555 strict_probability_domain,
5556 probability_epsilon,
5557 exact_probability,
5558 max_bdd_variables,
5559 warnings_slot,
5560 approximate_slot,
5561 top_k_proofs,
5562 timeout_flag,
5563 semiring_kind,
5564 classifier_registry,
5565 classifier_cache,
5566 classifier_provenance_store,
5567 )
5568 .await
5569 };
5570
5571 Ok(Box::pin(FixpointStream {
5572 state: FixpointStreamState::Running(Box::pin(fut)),
5573 schema: Arc::clone(&self.output_schema),
5574 metrics,
5575 }))
5576 }
5577
5578 fn metrics(&self) -> Option<MetricsSet> {
5579 Some(self.metrics.clone_inner())
5580 }
5581}
5582
/// State machine driving `FixpointStream`.
///
/// Progresses `Running` -> `Emitting` -> `Done`. It goes straight from
/// `Running` to `Done` when the computation fails or yields no batches.
enum FixpointStreamState {
    /// The fixpoint computation future is still being polled; it resolves to all
    /// output batches at once (or an error).
    Running(Pin<Box<dyn std::future::Future<Output = DFResult<Vec<RecordBatch>>> + Send>>),
    /// The computation has finished: holds its batches and the index of the next
    /// batch to hand out (starts at 0 and is incremented per emitted batch).
    Emitting(Vec<RecordBatch>, usize),
    /// Terminal state; every further poll returns `None`.
    Done,
}
5595
/// Record-batch stream returned by `FixpointExec::execute`.
///
/// Polls the fixpoint future to completion and then yields the collected batches
/// one per poll.
struct FixpointStream {
    /// Current position in the `Running` -> `Emitting` -> `Done` progression.
    state: FixpointStreamState,
    /// Schema reported through `RecordBatchStream::schema`.
    schema: SchemaRef,
    /// Baseline metrics: polling time goes to `elapsed_compute`, and emitted row
    /// counts are recorded as output.
    metrics: BaselineMetrics,
}
5601
5602impl Stream for FixpointStream {
5603 type Item = DFResult<RecordBatch>;
5604
5605 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
5606 let this = self.get_mut();
5607 let metrics = this.metrics.clone();
5608 let _timer = metrics.elapsed_compute().timer();
5609 loop {
5610 match &mut this.state {
5611 FixpointStreamState::Running(fut) => match fut.as_mut().poll(cx) {
5612 Poll::Ready(Ok(batches)) => {
5613 if batches.is_empty() {
5614 this.state = FixpointStreamState::Done;
5615 return Poll::Ready(None);
5616 }
5617 this.state = FixpointStreamState::Emitting(batches, 0);
5618 }
5620 Poll::Ready(Err(e)) => {
5621 this.state = FixpointStreamState::Done;
5622 return Poll::Ready(Some(Err(e)));
5623 }
5624 Poll::Pending => return Poll::Pending,
5625 },
5626 FixpointStreamState::Emitting(batches, idx) => {
5627 if *idx >= batches.len() {
5628 this.state = FixpointStreamState::Done;
5629 return Poll::Ready(None);
5630 }
5631 let batch = batches[*idx].clone();
5632 *idx += 1;
5633 this.metrics.record_output(batch.num_rows());
5634 return Poll::Ready(Some(Ok(batch)));
5635 }
5636 FixpointStreamState::Done => return Poll::Ready(None),
5637 }
5638 }
5639 }
5640}
5641
5642impl RecordBatchStream for FixpointStream {
5643 fn schema(&self) -> SchemaRef {
5644 Arc::clone(&self.schema)
5645 }
5646}
5647
5648#[cfg(test)]
5653mod tests {
5654 use super::*;
5655 use arrow_array::{Float64Array, Int64Array, StringArray};
5656 use arrow_schema::{DataType, Field, Schema};
5657
5658 fn test_schema() -> SchemaRef {
5659 Arc::new(Schema::new(vec![
5660 Field::new("name", DataType::Utf8, true),
5661 Field::new("value", DataType::Int64, true),
5662 ]))
5663 }
5664
5665 fn make_batch(names: &[&str], values: &[i64]) -> RecordBatch {
5666 RecordBatch::try_new(
5667 test_schema(),
5668 vec![
5669 Arc::new(StringArray::from(
5670 names.iter().map(|s| Some(*s)).collect::<Vec<_>>(),
5671 )),
5672 Arc::new(Int64Array::from(values.to_vec())),
5673 ],
5674 )
5675 .unwrap()
5676 }
5677
5678 #[tokio::test]
5681 async fn test_fixpoint_state_empty_facts_adds_all() {
5682 let schema = test_schema();
5683 let mut state = FixpointState::new("test".into(), schema, vec![0], 1_000_000, None, false);
5684
5685 let batch = make_batch(&["a", "b", "c"], &[1, 2, 3]);
5686 let changed = state.merge_delta(vec![batch], None).await.unwrap();
5687
5688 assert!(changed);
5689 assert_eq!(state.all_facts().len(), 1);
5690 assert_eq!(state.all_facts()[0].num_rows(), 3);
5691 assert_eq!(state.all_delta().len(), 1);
5692 assert_eq!(state.all_delta()[0].num_rows(), 3);
5693 }
5694
5695 #[tokio::test]
5696 async fn test_fixpoint_state_exact_duplicates_excluded() {
5697 let schema = test_schema();
5698 let mut state = FixpointState::new("test".into(), schema, vec![0], 1_000_000, None, false);
5699
5700 let batch1 = make_batch(&["a", "b"], &[1, 2]);
5701 state.merge_delta(vec![batch1], None).await.unwrap();
5702
5703 let batch2 = make_batch(&["a", "b"], &[1, 2]);
5705 let changed = state.merge_delta(vec![batch2], None).await.unwrap();
5706 assert!(!changed);
5707 assert!(
5708 state.all_delta().is_empty() || state.all_delta().iter().all(|b| b.num_rows() == 0)
5709 );
5710 }
5711
5712 #[tokio::test]
5713 async fn test_fixpoint_state_partial_overlap() {
5714 let schema = test_schema();
5715 let mut state = FixpointState::new("test".into(), schema, vec![0], 1_000_000, None, false);
5716
5717 let batch1 = make_batch(&["a", "b"], &[1, 2]);
5718 state.merge_delta(vec![batch1], None).await.unwrap();
5719
5720 let batch2 = make_batch(&["a", "c"], &[1, 3]);
5722 let changed = state.merge_delta(vec![batch2], None).await.unwrap();
5723 assert!(changed);
5724
5725 let delta_rows: usize = state.all_delta().iter().map(|b| b.num_rows()).sum();
5727 assert_eq!(delta_rows, 1);
5728
5729 let total_rows: usize = state.all_facts().iter().map(|b| b.num_rows()).sum();
5731 assert_eq!(total_rows, 3);
5732 }
5733
5734 #[tokio::test]
5735 async fn test_fixpoint_state_convergence() {
5736 let schema = test_schema();
5737 let mut state = FixpointState::new("test".into(), schema, vec![0], 1_000_000, None, false);
5738
5739 let batch = make_batch(&["a"], &[1]);
5740 state.merge_delta(vec![batch], None).await.unwrap();
5741
5742 let changed = state.merge_delta(vec![], None).await.unwrap();
5744 assert!(!changed);
5745 assert!(state.is_converged());
5746 }
5747
5748 #[test]
5751 fn test_row_dedup_persistent_across_calls() {
5752 let schema = test_schema();
5755 let mut rd = RowDedupState::try_new(&schema).expect("schema should be supported");
5756
5757 let batch1 = make_batch(&["a", "b"], &[1, 2]);
5758 let delta1 = rd.compute_delta(&[batch1], &schema).unwrap();
5759 let rows1: usize = delta1.iter().map(|b| b.num_rows()).sum();
5761 assert_eq!(rows1, 2);
5762
5763 let batch2 = make_batch(&["a", "b"], &[1, 2]);
5765 let delta2 = rd.compute_delta(&[batch2], &schema).unwrap();
5766 let rows2: usize = delta2.iter().map(|b| b.num_rows()).sum();
5767 assert_eq!(rows2, 0);
5768
5769 let batch3 = make_batch(&["a", "c"], &[1, 3]);
5771 let delta3 = rd.compute_delta(&[batch3], &schema).unwrap();
5772 let rows3: usize = delta3.iter().map(|b| b.num_rows()).sum();
5773 assert_eq!(rows3, 1);
5774 }
5775
5776 #[test]
5777 fn test_row_dedup_null_handling() {
5778 use arrow_array::StringArray;
5779 use arrow_schema::{DataType, Field, Schema};
5780
5781 let schema: SchemaRef = Arc::new(Schema::new(vec![
5782 Field::new("a", DataType::Utf8, true),
5783 Field::new("b", DataType::Int64, true),
5784 ]));
5785 let mut rd = RowDedupState::try_new(&schema).expect("schema should be supported");
5786
5787 let batch_nulls = RecordBatch::try_new(
5789 Arc::clone(&schema),
5790 vec![
5791 Arc::new(StringArray::from(vec![None::<&str>, None::<&str>])),
5792 Arc::new(arrow_array::Int64Array::from(vec![1i64, 1i64])),
5793 ],
5794 )
5795 .unwrap();
5796 let delta = rd.compute_delta(&[batch_nulls], &schema).unwrap();
5797 let rows: usize = delta.iter().map(|b| b.num_rows()).sum();
5798 assert_eq!(rows, 1, "two identical NULL rows should be deduped to one");
5799
5800 let batch_diff = RecordBatch::try_new(
5802 Arc::clone(&schema),
5803 vec![
5804 Arc::new(StringArray::from(vec![None::<&str>])),
5805 Arc::new(arrow_array::Int64Array::from(vec![2i64])),
5806 ],
5807 )
5808 .unwrap();
5809 let delta2 = rd.compute_delta(&[batch_diff], &schema).unwrap();
5810 let rows2: usize = delta2.iter().map(|b| b.num_rows()).sum();
5811 assert_eq!(rows2, 1, "(NULL, 2) is distinct from (NULL, 1)");
5812 }
5813
5814 #[test]
5815 fn test_row_dedup_within_candidate_dedup() {
5816 let schema = test_schema();
5818 let mut rd = RowDedupState::try_new(&schema).expect("schema should be supported");
5819
5820 let batch = make_batch(&["a", "a", "b"], &[1, 1, 2]);
5822 let delta = rd.compute_delta(&[batch], &schema).unwrap();
5823 let rows: usize = delta.iter().map(|b| b.num_rows()).sum();
5824 assert_eq!(rows, 2, "within-batch dup should be collapsed: a:1, b:2");
5825 }
5826
5827 #[test]
5830 fn test_round_float_columns_near_duplicates() {
5831 let schema = Arc::new(Schema::new(vec![
5832 Field::new("name", DataType::Utf8, true),
5833 Field::new("dist", DataType::Float64, true),
5834 ]));
5835 let batch = RecordBatch::try_new(
5836 Arc::clone(&schema),
5837 vec![
5838 Arc::new(StringArray::from(vec![Some("a"), Some("a")])),
5839 Arc::new(Float64Array::from(vec![1.0000000000001, 1.0000000000002])),
5840 ],
5841 )
5842 .unwrap();
5843
5844 let rounded = round_float_columns(&[batch]);
5845 assert_eq!(rounded.len(), 1);
5846 let col = rounded[0]
5847 .column(1)
5848 .as_any()
5849 .downcast_ref::<Float64Array>()
5850 .unwrap();
5851 assert_eq!(col.value(0), col.value(1));
5853 }
5854
5855 #[test]
5858 fn test_registry_write_read_round_trip() {
5859 let schema = test_schema();
5860 let data = Arc::new(RwLock::new(Vec::new()));
5861 let mut reg = DerivedScanRegistry::new();
5862 reg.add(DerivedScanEntry {
5863 scan_index: 0,
5864 rule_name: "reachable".into(),
5865 is_self_ref: true,
5866 data: Arc::clone(&data),
5867 schema: Arc::clone(&schema),
5868 });
5869
5870 let batch = make_batch(&["x"], &[42]);
5871 reg.write_data(0, vec![batch.clone()]);
5872
5873 let entry = reg.get(0).unwrap();
5874 let guard = entry.data.read();
5875 assert_eq!(guard.len(), 1);
5876 assert_eq!(guard[0].num_rows(), 1);
5877 }
5878
5879 #[test]
5880 fn test_registry_entries_for_rule() {
5881 let schema = test_schema();
5882 let mut reg = DerivedScanRegistry::new();
5883 reg.add(DerivedScanEntry {
5884 scan_index: 0,
5885 rule_name: "r1".into(),
5886 is_self_ref: true,
5887 data: Arc::new(RwLock::new(Vec::new())),
5888 schema: Arc::clone(&schema),
5889 });
5890 reg.add(DerivedScanEntry {
5891 scan_index: 1,
5892 rule_name: "r2".into(),
5893 is_self_ref: false,
5894 data: Arc::new(RwLock::new(Vec::new())),
5895 schema: Arc::clone(&schema),
5896 });
5897 reg.add(DerivedScanEntry {
5898 scan_index: 2,
5899 rule_name: "r1".into(),
5900 is_self_ref: false,
5901 data: Arc::new(RwLock::new(Vec::new())),
5902 schema: Arc::clone(&schema),
5903 });
5904
5905 assert_eq!(reg.entries_for_rule("r1").len(), 2);
5906 assert_eq!(reg.entries_for_rule("r2").len(), 1);
5907 assert_eq!(reg.entries_for_rule("r3").len(), 0);
5908 }
5909
5910 #[test]
5913 fn test_monotonic_agg_update_and_stability() {
5914 let bindings = vec![MonotonicFoldBinding {
5915 fold_name: "total".into(),
5916 aggregate: std::sync::Arc::new(uni_plugin_builtin::locy_aggregates::SumAgg),
5917 input_col_index: 1,
5918 input_col_name: None,
5919 }];
5920 let mut agg = MonotonicAggState::new(bindings);
5921
5922 let batch = make_batch(&["a"], &[10]);
5924 agg.snapshot();
5925 let changed = agg
5926 .update(&[0], &[batch], false, SemiringKind::AddMultProb)
5927 .unwrap();
5928 assert!(changed);
5929 assert!(!agg.is_stable()); agg.snapshot();
5933 let changed = agg
5934 .update(&[0], &[], false, SemiringKind::AddMultProb)
5935 .unwrap();
5936 assert!(!changed);
5937 assert!(agg.is_stable());
5938 }
5939
5940 #[tokio::test]
5943 async fn test_memory_limit_exceeded() {
5944 let schema = test_schema();
5945 let mut state = FixpointState::new("test".into(), schema, vec![0], 1, None, false);
5947
5948 let batch = make_batch(&["a", "b", "c"], &[1, 2, 3]);
5949 let result = state.merge_delta(vec![batch], None).await;
5950 assert!(result.is_err());
5951 let err = result.unwrap_err().to_string();
5952 assert!(err.contains("memory limit"), "Error was: {}", err);
5953 }
5954
5955 #[tokio::test]
5958 async fn test_fixpoint_stream_emitting() {
5959 use futures::StreamExt;
5960
5961 let schema = test_schema();
5962 let batch1 = make_batch(&["a"], &[1]);
5963 let batch2 = make_batch(&["b"], &[2]);
5964
5965 let metrics = ExecutionPlanMetricsSet::new();
5966 let baseline = BaselineMetrics::new(&metrics, 0);
5967
5968 let mut stream = FixpointStream {
5969 state: FixpointStreamState::Emitting(vec![batch1, batch2], 0),
5970 schema,
5971 metrics: baseline,
5972 };
5973
5974 let stream = Pin::new(&mut stream);
5975 let batches: Vec<RecordBatch> = stream.filter_map(|r| async { r.ok() }).collect().await;
5976
5977 assert_eq!(batches.len(), 2);
5978 assert_eq!(batches[0].num_rows(), 1);
5979 assert_eq!(batches[1].num_rows(), 1);
5980 }
5981
5982 fn make_f64_batch(names: &[&str], values: &[f64]) -> RecordBatch {
5985 let schema = Arc::new(Schema::new(vec![
5986 Field::new("name", DataType::Utf8, true),
5987 Field::new("value", DataType::Float64, true),
5988 ]));
5989 RecordBatch::try_new(
5990 schema,
5991 vec![
5992 Arc::new(StringArray::from(
5993 names.iter().map(|s| Some(*s)).collect::<Vec<_>>(),
5994 )),
5995 Arc::new(Float64Array::from(values.to_vec())),
5996 ],
5997 )
5998 .unwrap()
5999 }
6000
6001 fn make_nor_binding() -> Vec<MonotonicFoldBinding> {
6002 vec![MonotonicFoldBinding {
6003 fold_name: "prob".into(),
6004 aggregate: std::sync::Arc::new(uni_plugin_builtin::locy_aggregates::MnorAgg),
6005 input_col_index: 1,
6006 input_col_name: None,
6007 }]
6008 }
6009
6010 fn make_prod_binding() -> Vec<MonotonicFoldBinding> {
6011 vec![MonotonicFoldBinding {
6012 fold_name: "prob".into(),
6013 aggregate: std::sync::Arc::new(uni_plugin_builtin::locy_aggregates::MprodAgg),
6014 input_col_index: 1,
6015 input_col_name: None,
6016 }]
6017 }
6018
6019 fn acc_key(name: &str) -> (Vec<ScalarKey>, String) {
6020 (vec![ScalarKey::Utf8(name.to_string())], "prob".to_string())
6021 }
6022
6023 #[test]
6024 fn test_monotonic_nor_first_update() {
6025 let mut agg = MonotonicAggState::new(make_nor_binding());
6026 let batch = make_f64_batch(&["a"], &[0.3]);
6027 let changed = agg
6028 .update(&[0], &[batch], false, SemiringKind::AddMultProb)
6029 .unwrap();
6030 assert!(changed);
6031 let val = agg.get_accumulator(&acc_key("a")).unwrap();
6032 assert!((val - 0.3).abs() < 1e-10, "expected 0.3, got {}", val);
6033 }
6034
6035 #[test]
6036 fn test_monotonic_nor_two_updates() {
6037 let mut agg = MonotonicAggState::new(make_nor_binding());
6039 let batch1 = make_f64_batch(&["a"], &[0.3]);
6040 agg.update(&[0], &[batch1], false, SemiringKind::AddMultProb)
6041 .unwrap();
6042 let batch2 = make_f64_batch(&["a"], &[0.5]);
6043 agg.update(&[0], &[batch2], false, SemiringKind::AddMultProb)
6044 .unwrap();
6045 let val = agg.get_accumulator(&acc_key("a")).unwrap();
6046 assert!((val - 0.65).abs() < 1e-10, "expected 0.65, got {}", val);
6047 }
6048
6049 #[test]
6050 fn test_monotonic_prod_first_update() {
6051 let mut agg = MonotonicAggState::new(make_prod_binding());
6052 let batch = make_f64_batch(&["a"], &[0.6]);
6053 let changed = agg
6054 .update(&[0], &[batch], false, SemiringKind::AddMultProb)
6055 .unwrap();
6056 assert!(changed);
6057 let val = agg.get_accumulator(&acc_key("a")).unwrap();
6058 assert!((val - 0.6).abs() < 1e-10, "expected 0.6, got {}", val);
6059 }
6060
6061 #[test]
6062 fn test_monotonic_prod_two_updates() {
6063 let mut agg = MonotonicAggState::new(make_prod_binding());
6065 let batch1 = make_f64_batch(&["a"], &[0.6]);
6066 agg.update(&[0], &[batch1], false, SemiringKind::AddMultProb)
6067 .unwrap();
6068 let batch2 = make_f64_batch(&["a"], &[0.8]);
6069 agg.update(&[0], &[batch2], false, SemiringKind::AddMultProb)
6070 .unwrap();
6071 let val = agg.get_accumulator(&acc_key("a")).unwrap();
6072 assert!((val - 0.48).abs() < 1e-10, "expected 0.48, got {}", val);
6073 }
6074
6075 #[test]
6076 fn test_monotonic_nor_stability() {
6077 let mut agg = MonotonicAggState::new(make_nor_binding());
6078 let batch = make_f64_batch(&["a"], &[0.3]);
6079 agg.update(&[0], &[batch], false, SemiringKind::AddMultProb)
6080 .unwrap();
6081 agg.snapshot();
6082 let changed = agg
6083 .update(&[0], &[], false, SemiringKind::AddMultProb)
6084 .unwrap();
6085 assert!(!changed);
6086 assert!(agg.is_stable());
6087 }
6088
6089 #[test]
6090 fn test_monotonic_prod_stability() {
6091 let mut agg = MonotonicAggState::new(make_prod_binding());
6092 let batch = make_f64_batch(&["a"], &[0.6]);
6093 agg.update(&[0], &[batch], false, SemiringKind::AddMultProb)
6094 .unwrap();
6095 agg.snapshot();
6096 let changed = agg
6097 .update(&[0], &[], false, SemiringKind::AddMultProb)
6098 .unwrap();
6099 assert!(!changed);
6100 assert!(agg.is_stable());
6101 }
6102
6103 #[test]
6104 fn test_monotonic_nor_multi_group() {
6105 let mut agg = MonotonicAggState::new(make_nor_binding());
6107 let batch1 = make_f64_batch(&["a", "b"], &[0.3, 0.5]);
6108 agg.update(&[0], &[batch1], false, SemiringKind::AddMultProb)
6109 .unwrap();
6110 let batch2 = make_f64_batch(&["a", "b"], &[0.5, 0.2]);
6111 agg.update(&[0], &[batch2], false, SemiringKind::AddMultProb)
6112 .unwrap();
6113
6114 let val_a = agg.get_accumulator(&acc_key("a")).unwrap();
6115 let val_b = agg.get_accumulator(&acc_key("b")).unwrap();
6116 assert!(
6117 (val_a - 0.65).abs() < 1e-10,
6118 "expected a=0.65, got {}",
6119 val_a
6120 );
6121 assert!((val_b - 0.6).abs() < 1e-10, "expected b=0.6, got {}", val_b);
6122 }
6123
6124 #[test]
6125 fn test_monotonic_prod_zero_absorbing() {
6126 let mut agg = MonotonicAggState::new(make_prod_binding());
6128 let batch1 = make_f64_batch(&["a"], &[0.5]);
6129 agg.update(&[0], &[batch1], false, SemiringKind::AddMultProb)
6130 .unwrap();
6131 let batch2 = make_f64_batch(&["a"], &[0.0]);
6132 agg.update(&[0], &[batch2], false, SemiringKind::AddMultProb)
6133 .unwrap();
6134
6135 let val = agg.get_accumulator(&acc_key("a")).unwrap();
6136 assert!((val - 0.0).abs() < 1e-10, "expected 0.0, got {}", val);
6137
6138 agg.snapshot();
6140 let batch3 = make_f64_batch(&["a"], &[0.5]);
6141 let changed = agg
6142 .update(&[0], &[batch3], false, SemiringKind::AddMultProb)
6143 .unwrap();
6144 assert!(!changed);
6145 assert!(agg.is_stable());
6146 }
6147
6148 #[test]
6149 fn test_monotonic_nor_clamping() {
6150 let mut agg = MonotonicAggState::new(make_nor_binding());
6152 let batch = make_f64_batch(&["a"], &[1.5]);
6153 agg.update(&[0], &[batch], false, SemiringKind::AddMultProb)
6154 .unwrap();
6155 let val = agg.get_accumulator(&acc_key("a")).unwrap();
6156 assert!((val - 1.0).abs() < 1e-10, "expected 1.0, got {}", val);
6157 }
6158
6159 #[test]
6160 fn test_monotonic_nor_absorbing() {
6161 let mut agg = MonotonicAggState::new(make_nor_binding());
6163 let batch1 = make_f64_batch(&["a"], &[0.3]);
6164 agg.update(&[0], &[batch1], false, SemiringKind::AddMultProb)
6165 .unwrap();
6166 let batch2 = make_f64_batch(&["a"], &[1.0]);
6167 agg.update(&[0], &[batch2], false, SemiringKind::AddMultProb)
6168 .unwrap();
6169 let val = agg.get_accumulator(&acc_key("a")).unwrap();
6170 assert!((val - 1.0).abs() < 1e-10, "expected 1.0, got {}", val);
6171 }
6172
6173 #[test]
6176 fn test_monotonic_agg_strict_nor_rejects() {
6177 let mut agg = MonotonicAggState::new(make_nor_binding());
6178 let batch = make_f64_batch(&["a"], &[1.5]);
6179 let result = agg.update(&[0], &[batch], true, SemiringKind::AddMultProb);
6180 assert!(result.is_err());
6181 let err = result.unwrap_err().to_string();
6182 assert!(
6183 err.contains("strict_probability_domain"),
6184 "Expected strict error, got: {}",
6185 err
6186 );
6187 }
6188
6189 #[test]
6190 fn test_monotonic_agg_strict_prod_rejects() {
6191 let mut agg = MonotonicAggState::new(make_prod_binding());
6192 let batch = make_f64_batch(&["a"], &[2.0]);
6193 let result = agg.update(&[0], &[batch], true, SemiringKind::AddMultProb);
6194 assert!(result.is_err());
6195 let err = result.unwrap_err().to_string();
6196 assert!(
6197 err.contains("strict_probability_domain"),
6198 "Expected strict error, got: {}",
6199 err
6200 );
6201 }
6202
6203 #[test]
6204 fn test_monotonic_agg_strict_accepts_valid() {
6205 let mut agg = MonotonicAggState::new(make_nor_binding());
6206 let batch = make_f64_batch(&["a"], &[0.5]);
6207 let result = agg.update(&[0], &[batch], true, SemiringKind::AddMultProb);
6208 assert!(result.is_ok());
6209 let val = agg.get_accumulator(&acc_key("a")).unwrap();
6210 assert!((val - 0.5).abs() < 1e-10, "expected 0.5, got {}", val);
6211 }
6212
6213 fn make_vid_prob_batch(vids: &[u64], probs: &[f64]) -> RecordBatch {
6216 use arrow_array::UInt64Array;
6217 let schema = Arc::new(Schema::new(vec![
6218 Field::new("vid", DataType::UInt64, true),
6219 Field::new("prob", DataType::Float64, true),
6220 ]));
6221 RecordBatch::try_new(
6222 schema,
6223 vec![
6224 Arc::new(UInt64Array::from(vids.to_vec())),
6225 Arc::new(Float64Array::from(probs.to_vec())),
6226 ],
6227 )
6228 .unwrap()
6229 }
6230
6231 #[test]
6232 fn test_prob_complement_basic() {
6233 let body = make_vid_prob_batch(&[1, 2], &[0.9, 0.8]);
6235 let neg = make_vid_prob_batch(&[1], &[0.7]);
6236 let join_cols = vec![("vid".to_string(), "vid".to_string())];
6237 let result = apply_prob_complement_composite(
6238 vec![body],
6239 &[neg],
6240 &join_cols,
6241 "prob",
6242 "__complement_0",
6243 )
6244 .unwrap();
6245 assert_eq!(result.len(), 1);
6246 let batch = &result[0];
6247 let complement = batch
6248 .column_by_name("__complement_0")
6249 .unwrap()
6250 .as_any()
6251 .downcast_ref::<Float64Array>()
6252 .unwrap();
6253 assert!(
6255 (complement.value(0) - 0.3).abs() < 1e-10,
6256 "expected 0.3, got {}",
6257 complement.value(0)
6258 );
6259 assert!(
6261 (complement.value(1) - 1.0).abs() < 1e-10,
6262 "expected 1.0, got {}",
6263 complement.value(1)
6264 );
6265 }
6266
6267 #[test]
6268 fn test_prob_complement_noisy_or_duplicates() {
6269 let body = make_vid_prob_batch(&[1], &[0.9]);
6273 let neg = make_vid_prob_batch(&[1, 1], &[0.3, 0.5]);
6274 let join_cols = vec![("vid".to_string(), "vid".to_string())];
6275 let result = apply_prob_complement_composite(
6276 vec![body],
6277 &[neg],
6278 &join_cols,
6279 "prob",
6280 "__complement_0",
6281 )
6282 .unwrap();
6283 let batch = &result[0];
6284 let complement = batch
6285 .column_by_name("__complement_0")
6286 .unwrap()
6287 .as_any()
6288 .downcast_ref::<Float64Array>()
6289 .unwrap();
6290 assert!(
6291 (complement.value(0) - 0.35).abs() < 1e-10,
6292 "expected 0.35, got {}",
6293 complement.value(0)
6294 );
6295 }
6296
6297 #[test]
6298 fn test_prob_complement_empty_neg() {
6299 let body = make_vid_prob_batch(&[1, 2], &[0.5, 0.6]);
6301 let join_cols = vec![("vid".to_string(), "vid".to_string())];
6302 let result =
6303 apply_prob_complement_composite(vec![body], &[], &join_cols, "prob", "__complement_0")
6304 .unwrap();
6305 let batch = &result[0];
6306 let complement = batch
6307 .column_by_name("__complement_0")
6308 .unwrap()
6309 .as_any()
6310 .downcast_ref::<Float64Array>()
6311 .unwrap();
6312 for i in 0..2 {
6313 assert!(
6314 (complement.value(i) - 1.0).abs() < 1e-10,
6315 "row {}: expected 1.0, got {}",
6316 i,
6317 complement.value(i)
6318 );
6319 }
6320 }
6321
6322 #[test]
6323 fn test_anti_join_basic() {
6324 use arrow_array::UInt64Array;
6326 let body = make_vid_prob_batch(&[1, 2, 3], &[0.5, 0.6, 0.7]);
6327 let neg = make_vid_prob_batch(&[2], &[0.0]);
6328 let join_cols = vec![("vid".to_string(), "vid".to_string())];
6329 let result = apply_anti_join_composite(vec![body], &[neg], &join_cols).unwrap();
6330 assert_eq!(result.len(), 1);
6331 let batch = &result[0];
6332 assert_eq!(batch.num_rows(), 2);
6333 let vids = batch
6334 .column_by_name("vid")
6335 .unwrap()
6336 .as_any()
6337 .downcast_ref::<UInt64Array>()
6338 .unwrap();
6339 assert_eq!(vids.value(0), 1);
6340 assert_eq!(vids.value(1), 3);
6341 }
6342
6343 #[test]
6344 fn test_anti_join_empty_neg() {
6345 let body = make_vid_prob_batch(&[1, 2, 3], &[0.5, 0.6, 0.7]);
6347 let join_cols = vec![("vid".to_string(), "vid".to_string())];
6348 let result = apply_anti_join_composite(vec![body], &[], &join_cols).unwrap();
6349 assert_eq!(result.len(), 1);
6350 assert_eq!(result[0].num_rows(), 3);
6351 }
6352
6353 #[test]
6354 fn test_anti_join_all_excluded() {
6355 let body = make_vid_prob_batch(&[1, 2], &[0.5, 0.6]);
6357 let neg = make_vid_prob_batch(&[1, 2], &[0.0, 0.0]);
6358 let join_cols = vec![("vid".to_string(), "vid".to_string())];
6359 let result = apply_anti_join_composite(vec![body], &[neg], &join_cols).unwrap();
6360 let total: usize = result.iter().map(|b| b.num_rows()).sum();
6361 assert_eq!(total, 0);
6362 }
6363
6364 #[test]
6365 fn test_multiply_prob_single_complement() {
6366 let body = make_vid_prob_batch(&[1], &[0.8]);
6368 let complement_arr = Float64Array::from(vec![0.5]);
6370 let mut cols: Vec<arrow_array::ArrayRef> = body.columns().to_vec();
6371 cols.push(Arc::new(complement_arr));
6372 let mut fields: Vec<Arc<Field>> = body.schema().fields().iter().cloned().collect();
6373 fields.push(Arc::new(Field::new(
6374 "__complement_0",
6375 DataType::Float64,
6376 true,
6377 )));
6378 let schema = Arc::new(Schema::new(fields));
6379 let batch = RecordBatch::try_new(schema, cols).unwrap();
6380
6381 let result =
6382 multiply_prob_factors(vec![batch], Some("prob"), &["__complement_0".to_string()])
6383 .unwrap();
6384 assert_eq!(result.len(), 1);
6385 let out = &result[0];
6386 assert!(out.column_by_name("__complement_0").is_none());
6388 let prob = out
6389 .column_by_name("prob")
6390 .unwrap()
6391 .as_any()
6392 .downcast_ref::<Float64Array>()
6393 .unwrap();
6394 assert!(
6395 (prob.value(0) - 0.4).abs() < 1e-10,
6396 "expected 0.4, got {}",
6397 prob.value(0)
6398 );
6399 }
6400
6401 #[test]
6402 fn test_multiply_prob_multiple_complements() {
6403 let body = make_vid_prob_batch(&[1], &[0.8]);
6405 let c1 = Float64Array::from(vec![0.5]);
6406 let c2 = Float64Array::from(vec![0.6]);
6407 let mut cols: Vec<arrow_array::ArrayRef> = body.columns().to_vec();
6408 cols.push(Arc::new(c1));
6409 cols.push(Arc::new(c2));
6410 let mut fields: Vec<Arc<Field>> = body.schema().fields().iter().cloned().collect();
6411 fields.push(Arc::new(Field::new("__c1", DataType::Float64, true)));
6412 fields.push(Arc::new(Field::new("__c2", DataType::Float64, true)));
6413 let schema = Arc::new(Schema::new(fields));
6414 let batch = RecordBatch::try_new(schema, cols).unwrap();
6415
6416 let result = multiply_prob_factors(
6417 vec![batch],
6418 Some("prob"),
6419 &["__c1".to_string(), "__c2".to_string()],
6420 )
6421 .unwrap();
6422 let out = &result[0];
6423 assert!(out.column_by_name("__c1").is_none());
6424 assert!(out.column_by_name("__c2").is_none());
6425 let prob = out
6426 .column_by_name("prob")
6427 .unwrap()
6428 .as_any()
6429 .downcast_ref::<Float64Array>()
6430 .unwrap();
6431 assert!(
6432 (prob.value(0) - 0.24).abs() < 1e-10,
6433 "expected 0.24, got {}",
6434 prob.value(0)
6435 );
6436 }
6437
6438 #[test]
6439 fn test_multiply_prob_no_prob_column() {
6440 use arrow_array::UInt64Array;
6442 let schema = Arc::new(Schema::new(vec![
6443 Field::new("vid", DataType::UInt64, true),
6444 Field::new("__c1", DataType::Float64, true),
6445 ]));
6446 let batch = RecordBatch::try_new(
6447 schema,
6448 vec![
6449 Arc::new(UInt64Array::from(vec![1u64])),
6450 Arc::new(Float64Array::from(vec![0.7])),
6451 ],
6452 )
6453 .unwrap();
6454
6455 let result = multiply_prob_factors(vec![batch], None, &["__c1".to_string()]).unwrap();
6456 let out = &result[0];
6457 assert!(out.column_by_name("__c1").is_none());
6459 assert_eq!(out.num_columns(), 1);
6461 }
6462}