1use crate::query::df_graph::GraphExecutionContext;
10use crate::query::df_graph::common::{
11 ScalarKey, collect_all_partitions, compute_plan_properties, execute_subplan, extract_scalar_key,
12};
13use crate::query::df_graph::locy_best_by::{BestByExec, SortCriterion};
14use crate::query::df_graph::locy_errors::LocyRuntimeError;
15use crate::query::df_graph::locy_explain::{DerivationEntry, DerivationTracker};
16use crate::query::df_graph::locy_fold::{FoldBinding, FoldExec};
17use crate::query::df_graph::locy_priority::PriorityExec;
18use crate::query::planner::LogicalPlan;
19use arrow_array::RecordBatch;
20use arrow_row::{RowConverter, SortField};
21use arrow_schema::SchemaRef;
22use datafusion::common::JoinType;
23use datafusion::common::Result as DFResult;
24use datafusion::execution::{RecordBatchStream, SendableRecordBatchStream, TaskContext};
25use datafusion::physical_plan::joins::{HashJoinExec, PartitionMode};
26use datafusion::physical_plan::memory::MemoryStream;
27use datafusion::physical_plan::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet};
28use datafusion::physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties};
29use futures::Stream;
30use parking_lot::RwLock;
31use std::any::Any;
32use std::collections::{HashMap, HashSet};
33use std::fmt;
34use std::pin::Pin;
35use std::sync::{Arc, RwLock as StdRwLock};
36use std::task::{Context, Poll};
37use std::time::{Duration, Instant};
38use uni_common::Value;
39use uni_common::core::schema::Schema as UniSchema;
40use uni_store::storage::manager::StorageManager;
41
#[derive(Debug)]
/// One registered derived-relation scan site inside a fixpoint plan.
///
/// The shared `data` handle is rewritten by the fixpoint driver before each
/// rule body executes, so `DerivedScanExec` nodes reading it always see the
/// current facts (or delta, for self-references).
pub struct DerivedScanEntry {
    /// Unique identifier of this scan site within the plan.
    pub scan_index: usize,
    /// Name of the rule whose derived facts this scan reads.
    pub rule_name: String,
    /// True when the scan reads the rule currently being evaluated
    /// (semi-naive evaluation feeds it the delta instead of all facts).
    pub is_self_ref: bool,
    /// Shared batch storage, refreshed each iteration by the driver.
    pub data: Arc<RwLock<Vec<RecordBatch>>>,
    /// Schema of the derived relation (used to emit empty batches).
    pub schema: SchemaRef,
}
64
#[derive(Debug, Default)]
/// Collection of all derived-scan sites referenced by a fixpoint plan,
/// looked up by `scan_index` or by the producing rule's name.
pub struct DerivedScanRegistry {
    // Flat list; lookups are linear scans (entry counts are small).
    entries: Vec<DerivedScanEntry>,
}
75
76impl DerivedScanRegistry {
77 pub fn new() -> Self {
79 Self::default()
80 }
81
82 pub fn add(&mut self, entry: DerivedScanEntry) {
84 self.entries.push(entry);
85 }
86
87 pub fn get(&self, scan_index: usize) -> Option<&DerivedScanEntry> {
89 self.entries.iter().find(|e| e.scan_index == scan_index)
90 }
91
92 pub fn write_data(&self, scan_index: usize, batches: Vec<RecordBatch>) {
94 if let Some(entry) = self.get(scan_index) {
95 let mut guard = entry.data.write();
96 *guard = batches;
97 }
98 }
99
100 pub fn entries_for_rule(&self, rule_name: &str) -> Vec<&DerivedScanEntry> {
102 self.entries
103 .iter()
104 .filter(|e| e.rule_name == rule_name)
105 .collect()
106 }
107}
108
#[derive(Debug, Clone)]
/// Lightweight copy of a FOLD binding used for convergence tracking only
/// (the real aggregation happens in `FoldExec` after the fixpoint).
pub struct MonotonicFoldBinding {
    /// Output column name of the fold; part of the accumulator key.
    pub fold_name: String,
    /// Aggregation kind (Sum/Count/Min/Max/Avg/Collect).
    pub kind: crate::query::df_graph::locy_fold::FoldAggKind,
    /// Column index in the rule's yield schema that feeds the fold.
    pub input_col_index: usize,
}
120
#[derive(Debug)]
/// Tracks aggregate values across fixpoint iterations so a rule with FOLD
/// bindings is only considered converged once its aggregates stop changing.
pub struct MonotonicAggState {
    // Current aggregate value keyed by (group key, fold name).
    accumulators: HashMap<(Vec<ScalarKey>, String), f64>,
    // Accumulator values as of the previous snapshot; compared in
    // `is_stable` to detect aggregate convergence.
    prev_snapshot: HashMap<(Vec<ScalarKey>, String), f64>,
    // The fold bindings being tracked.
    bindings: Vec<MonotonicFoldBinding>,
}
135
impl MonotonicAggState {
    /// Creates an empty state tracking the given fold bindings.
    pub fn new(bindings: Vec<MonotonicFoldBinding>) -> Self {
        Self {
            accumulators: HashMap::new(),
            prev_snapshot: HashMap::new(),
            bindings,
        }
    }

    /// Folds the rows of `delta_batches` into the accumulators.
    ///
    /// `key_indices` select the grouping columns. Returns `true` when any
    /// accumulator changed by more than `f64::EPSILON`. Null and
    /// non-numeric values are skipped (see `extract_f64`).
    pub fn update(&mut self, key_indices: &[usize], delta_batches: &[RecordBatch]) -> bool {
        use crate::query::df_graph::locy_fold::FoldAggKind;

        let mut changed = false;
        for batch in delta_batches {
            for row_idx in 0..batch.num_rows() {
                let group_key = extract_scalar_key(batch, key_indices, row_idx);
                for binding in &self.bindings {
                    let col = batch.column(binding.input_col_index);
                    let val = extract_f64(col.as_ref(), row_idx);
                    if let Some(val) = val {
                        let map_key = (group_key.clone(), binding.fold_name.clone());
                        // Identity element per aggregation kind.
                        let entry =
                            self.accumulators
                                .entry(map_key)
                                .or_insert(match binding.kind {
                                    FoldAggKind::Sum | FoldAggKind::Count | FoldAggKind::Avg => 0.0,
                                    FoldAggKind::Max => f64::NEG_INFINITY,
                                    FoldAggKind::Min => f64::INFINITY,
                                    FoldAggKind::Collect => 0.0,
                                });
                        let old = *entry;
                        match binding.kind {
                            // NOTE(review): Count adds the column value rather
                            // than 1 per row — presumably the input column
                            // already carries per-row count contributions;
                            // confirm against the FOLD planner.
                            FoldAggKind::Sum | FoldAggKind::Count => *entry += val,
                            FoldAggKind::Max => {
                                if val > *entry {
                                    *entry = val;
                                }
                            }
                            FoldAggKind::Min => {
                                if val < *entry {
                                    *entry = val;
                                }
                            }
                            // Avg/Collect are never advanced here, so they do
                            // not influence convergence detection.
                            _ => {}
                        }
                        if (*entry - old).abs() > f64::EPSILON {
                            changed = true;
                        }
                    }
                }
            }
        }
        changed
    }

    /// Records the current accumulator values as the comparison baseline
    /// for the next `is_stable` check.
    pub fn snapshot(&mut self) {
        self.prev_snapshot = self.accumulators.clone();
    }

    /// Returns `true` when every accumulator matches its snapshotted value
    /// to within `f64::EPSILON` (and no groups were added or removed).
    pub fn is_stable(&self) -> bool {
        if self.accumulators.len() != self.prev_snapshot.len() {
            return false;
        }
        for (key, val) in &self.accumulators {
            match self.prev_snapshot.get(key) {
                Some(prev) if (*val - *prev).abs() <= f64::EPSILON => {}
                _ => return false,
            }
        }
        true
    }
}
212
213fn extract_f64(col: &dyn arrow_array::Array, row_idx: usize) -> Option<f64> {
215 if col.is_null(row_idx) {
216 return None;
217 }
218 if let Some(arr) = col.as_any().downcast_ref::<arrow_array::Float64Array>() {
219 Some(arr.value(row_idx))
220 } else {
221 col.as_any()
222 .downcast_ref::<arrow_array::Int64Array>()
223 .map(|arr| arr.value(row_idx) as f64)
224 }
225}
226
/// Row-level deduplication based on `arrow_row` byte encodings: each row is
/// converted to a comparable byte string and remembered in `seen`.
struct RowDedupState {
    // Converts batch columns into the row format for byte-wise comparison.
    converter: RowConverter,
    // Byte encodings of every row observed so far.
    seen: HashSet<Box<[u8]>>,
}
240
241impl RowDedupState {
242 fn try_new(schema: &SchemaRef) -> Option<Self> {
247 let fields: Vec<SortField> = schema
248 .fields()
249 .iter()
250 .map(|f| SortField::new(f.data_type().clone()))
251 .collect();
252 match RowConverter::new(fields) {
253 Ok(converter) => Some(Self {
254 converter,
255 seen: HashSet::new(),
256 }),
257 Err(e) => {
258 tracing::warn!(
259 "RowDedupState: RowConverter unsupported for schema, falling back to legacy dedup: {}",
260 e
261 );
262 None
263 }
264 }
265 }
266
267 fn compute_delta(
273 &mut self,
274 candidates: &[RecordBatch],
275 schema: &SchemaRef,
276 ) -> DFResult<Vec<RecordBatch>> {
277 let mut delta_batches = Vec::new();
278 for batch in candidates {
279 if batch.num_rows() == 0 {
280 continue;
281 }
282
283 let arrays: Vec<_> = batch.columns().to_vec();
285 let rows = self
286 .converter
287 .convert_columns(&arrays)
288 .map_err(|e| datafusion::error::DataFusionError::ArrowError(Box::new(e), None))?;
289
290 let mut keep = Vec::with_capacity(batch.num_rows());
292 for row_idx in 0..batch.num_rows() {
293 let row_bytes: Box<[u8]> = rows.row(row_idx).data().into();
294 keep.push(self.seen.insert(row_bytes));
295 }
296
297 let keep_mask = arrow_array::BooleanArray::from(keep);
298 let new_cols = batch
299 .columns()
300 .iter()
301 .map(|col| {
302 arrow::compute::filter(col.as_ref(), &keep_mask).map_err(|e| {
303 datafusion::error::DataFusionError::ArrowError(Box::new(e), None)
304 })
305 })
306 .collect::<DFResult<Vec<_>>>()?;
307
308 if new_cols.first().is_some_and(|c| !c.is_empty()) {
309 let filtered = RecordBatch::try_new(Arc::clone(schema), new_cols).map_err(|e| {
310 datafusion::error::DataFusionError::ArrowError(Box::new(e), None)
311 })?;
312 delta_batches.push(filtered);
313 }
314 }
315 Ok(delta_batches)
316 }
317}
318
/// Per-rule incremental state for semi-naive fixpoint evaluation: the full
/// fact set, the most recent delta, and the bookkeeping needed for dedup,
/// memory limiting, and aggregate convergence.
pub struct FixpointState {
    /// Rule name, used in error reporting.
    rule_name: String,
    /// All deduplicated facts derived so far.
    facts: Vec<RecordBatch>,
    /// Facts newly derived in the most recent iteration.
    delta: Vec<RecordBatch>,
    /// Schema shared by `facts` and `delta`.
    schema: SchemaRef,
    /// Indices of the rule's grouping-key columns (fed to the monotonic agg).
    key_column_indices: Vec<usize>,
    /// All column indices, precomputed for full-row scalar keys.
    all_column_indices: Vec<usize>,
    /// Running buffer-memory total of `facts`.
    facts_bytes: usize,
    /// Cap on `facts_bytes`; exceeding it aborts with a memory-limit error.
    max_derived_bytes: usize,
    /// Aggregate-stability tracker, present only for FOLD rules.
    monotonic_agg: Option<MonotonicAggState>,
    /// Row-encoding dedup state; `None` when the schema is unsupported.
    row_dedup: Option<RowDedupState>,
}
345
346impl FixpointState {
347 pub fn new(
349 rule_name: String,
350 schema: SchemaRef,
351 key_column_indices: Vec<usize>,
352 max_derived_bytes: usize,
353 monotonic_agg: Option<MonotonicAggState>,
354 ) -> Self {
355 let num_cols = schema.fields().len();
356 let row_dedup = RowDedupState::try_new(&schema);
357 Self {
358 rule_name,
359 facts: Vec::new(),
360 delta: Vec::new(),
361 schema,
362 key_column_indices,
363 all_column_indices: (0..num_cols).collect(),
364 facts_bytes: 0,
365 max_derived_bytes,
366 monotonic_agg,
367 row_dedup,
368 }
369 }
370
371 pub async fn merge_delta(
375 &mut self,
376 candidates: Vec<RecordBatch>,
377 task_ctx: Option<Arc<TaskContext>>,
378 ) -> DFResult<bool> {
379 if candidates.is_empty() || candidates.iter().all(|b| b.num_rows() == 0) {
380 self.delta.clear();
381 return Ok(false);
382 }
383
384 let candidates = round_float_columns(&candidates);
386
387 let delta = self.compute_delta(&candidates, task_ctx.as_ref()).await?;
389
390 if delta.is_empty() || delta.iter().all(|b| b.num_rows() == 0) {
391 self.delta.clear();
392 if let Some(ref mut agg) = self.monotonic_agg {
394 agg.snapshot();
395 }
396 return Ok(false);
397 }
398
399 let delta_bytes: usize = delta.iter().map(batch_byte_size).sum();
401 if self.facts_bytes + delta_bytes > self.max_derived_bytes {
402 return Err(datafusion::error::DataFusionError::Execution(
403 LocyRuntimeError::MemoryLimitExceeded {
404 rule: self.rule_name.clone(),
405 bytes: self.facts_bytes + delta_bytes,
406 limit: self.max_derived_bytes,
407 }
408 .to_string(),
409 ));
410 }
411
412 if let Some(ref mut agg) = self.monotonic_agg {
414 agg.snapshot();
415 agg.update(&self.key_column_indices, &delta);
416 }
417
418 self.facts_bytes += delta_bytes;
420 self.facts.extend(delta.iter().cloned());
421 self.delta = delta;
422
423 Ok(true)
424 }
425
426 async fn compute_delta(
433 &mut self,
434 candidates: &[RecordBatch],
435 task_ctx: Option<&Arc<TaskContext>>,
436 ) -> DFResult<Vec<RecordBatch>> {
437 let total_existing: usize = self.facts.iter().map(|b| b.num_rows()).sum();
438 if total_existing >= DEDUP_ANTI_JOIN_THRESHOLD
439 && let Some(ctx) = task_ctx
440 {
441 return arrow_left_anti_dedup(candidates.to_vec(), &self.facts, &self.schema, ctx)
442 .await;
443 }
444 if let Some(ref mut rd) = self.row_dedup {
445 rd.compute_delta(candidates, &self.schema)
446 } else {
447 self.compute_delta_legacy(candidates)
448 }
449 }
450
451 fn compute_delta_legacy(&self, candidates: &[RecordBatch]) -> DFResult<Vec<RecordBatch>> {
455 let mut existing: HashSet<Vec<ScalarKey>> = HashSet::new();
457 for batch in &self.facts {
458 for row_idx in 0..batch.num_rows() {
459 let key = extract_scalar_key(batch, &self.all_column_indices, row_idx);
460 existing.insert(key);
461 }
462 }
463
464 let mut delta_batches = Vec::new();
465 for batch in candidates {
466 if batch.num_rows() == 0 {
467 continue;
468 }
469 let mut keep = Vec::with_capacity(batch.num_rows());
471 for row_idx in 0..batch.num_rows() {
472 let key = extract_scalar_key(batch, &self.all_column_indices, row_idx);
473 keep.push(!existing.contains(&key));
474 }
475
476 for (row_idx, kept) in keep.iter_mut().enumerate() {
478 if *kept {
479 let key = extract_scalar_key(batch, &self.all_column_indices, row_idx);
480 if !existing.insert(key) {
481 *kept = false;
482 }
483 }
484 }
485
486 let keep_mask = arrow_array::BooleanArray::from(keep);
487 let new_rows = batch
488 .columns()
489 .iter()
490 .map(|col| {
491 arrow::compute::filter(col.as_ref(), &keep_mask).map_err(|e| {
492 datafusion::error::DataFusionError::ArrowError(Box::new(e), None)
493 })
494 })
495 .collect::<DFResult<Vec<_>>>()?;
496
497 if new_rows.first().is_some_and(|c| !c.is_empty()) {
498 let filtered =
499 RecordBatch::try_new(Arc::clone(&self.schema), new_rows).map_err(|e| {
500 datafusion::error::DataFusionError::ArrowError(Box::new(e), None)
501 })?;
502 delta_batches.push(filtered);
503 }
504 }
505
506 Ok(delta_batches)
507 }
508
509 pub fn is_converged(&self) -> bool {
511 let delta_empty = self.delta.is_empty() || self.delta.iter().all(|b| b.num_rows() == 0);
512 let agg_stable = self.monotonic_agg.as_ref().is_none_or(|a| a.is_stable());
513 delta_empty && agg_stable
514 }
515
516 pub fn all_facts(&self) -> &[RecordBatch] {
518 &self.facts
519 }
520
521 pub fn all_delta(&self) -> &[RecordBatch] {
523 &self.delta
524 }
525
526 pub fn into_facts(self) -> Vec<RecordBatch> {
528 self.facts
529 }
530}
531
532fn batch_byte_size(batch: &RecordBatch) -> usize {
534 batch
535 .columns()
536 .iter()
537 .map(|col| col.get_buffer_memory_size())
538 .sum()
539}
540
541fn round_float_columns(batches: &[RecordBatch]) -> Vec<RecordBatch> {
547 batches
548 .iter()
549 .map(|batch| {
550 let schema = batch.schema();
551 let has_float = schema
552 .fields()
553 .iter()
554 .any(|f| *f.data_type() == arrow_schema::DataType::Float64);
555 if !has_float {
556 return batch.clone();
557 }
558
559 let columns: Vec<arrow_array::ArrayRef> = batch
560 .columns()
561 .iter()
562 .enumerate()
563 .map(|(i, col)| {
564 if *schema.field(i).data_type() == arrow_schema::DataType::Float64 {
565 let arr = col
566 .as_any()
567 .downcast_ref::<arrow_array::Float64Array>()
568 .unwrap();
569 let rounded: arrow_array::Float64Array = arr
570 .iter()
571 .map(|v| v.map(|f| (f * 1e12).round() / 1e12))
572 .collect();
573 Arc::new(rounded) as arrow_array::ArrayRef
574 } else {
575 Arc::clone(col)
576 }
577 })
578 .collect();
579
580 RecordBatch::try_new(schema, columns).unwrap_or_else(|_| batch.clone())
581 })
582 .collect()
583}
584
/// Fact-row count above which dedup switches from the in-process hash paths
/// to an Arrow left-anti hash join (see `FixpointState::compute_delta`).
const DEDUP_ANTI_JOIN_THRESHOLD: usize = 300;
595
596async fn arrow_left_anti_dedup(
601 candidates: Vec<RecordBatch>,
602 existing: &[RecordBatch],
603 schema: &SchemaRef,
604 task_ctx: &Arc<TaskContext>,
605) -> DFResult<Vec<RecordBatch>> {
606 if existing.is_empty() || existing.iter().all(|b| b.num_rows() == 0) {
607 return Ok(candidates);
608 }
609
610 let left: Arc<dyn ExecutionPlan> = Arc::new(InMemoryExec::new(candidates, Arc::clone(schema)));
611 let right: Arc<dyn ExecutionPlan> =
612 Arc::new(InMemoryExec::new(existing.to_vec(), Arc::clone(schema)));
613
614 let on: Vec<(
615 Arc<dyn datafusion::physical_plan::PhysicalExpr>,
616 Arc<dyn datafusion::physical_plan::PhysicalExpr>,
617 )> = schema
618 .fields()
619 .iter()
620 .enumerate()
621 .map(|(i, field)| {
622 let l: Arc<dyn datafusion::physical_plan::PhysicalExpr> = Arc::new(
623 datafusion::physical_plan::expressions::Column::new(field.name(), i),
624 );
625 let r: Arc<dyn datafusion::physical_plan::PhysicalExpr> = Arc::new(
626 datafusion::physical_plan::expressions::Column::new(field.name(), i),
627 );
628 (l, r)
629 })
630 .collect();
631
632 if on.is_empty() {
633 return Ok(vec![]);
634 }
635
636 let join = HashJoinExec::try_new(
637 left,
638 right,
639 on,
640 None,
641 &JoinType::LeftAnti,
642 None,
643 PartitionMode::CollectLeft,
644 datafusion::common::NullEquality::NullEqualsNull,
645 )?;
646
647 let join_arc: Arc<dyn ExecutionPlan> = Arc::new(join);
648 collect_all_partitions(&join_arc, task_ctx.clone()).await
649}
650
#[derive(Debug, Clone)]
/// A clause-body reference to a derived relation (an IS-reference).
pub struct IsRefBinding {
    /// Scan-site index into the `DerivedScanRegistry`.
    pub derived_scan_index: usize,
    /// Name of the referenced rule.
    pub rule_name: String,
    /// True when the clause references its own rule (recursive clause).
    pub is_self_ref: bool,
    /// True for negated references (NOT); handled via anti-joins.
    pub negated: bool,
    /// (candidate column, derived-facts column) pairs used to anti-join
    /// candidate rows against the negated relation's facts.
    pub anti_join_cols: Vec<(String, String)>,
}
673
#[derive(Debug)]
/// One clause of a fixpoint rule: the logical plan of its body plus the
/// derived-relation references the body contains.
pub struct FixpointClausePlan {
    /// Logical plan executed each iteration to produce candidate rows.
    pub body_logical: LogicalPlan,
    /// Derived-relation references appearing in the body.
    pub is_ref_bindings: Vec<IsRefBinding>,
    /// Optional clause-level priority.
    pub priority: Option<i64>,
}
684
#[derive(Debug)]
/// Planned form of one fixpoint rule: its clauses, output schema, and the
/// optional post-fixpoint operators (PRIORITY / FOLD / BEST-BY).
pub struct FixpointRulePlan {
    /// Rule name; also the key used by derived-scan entries.
    pub name: String,
    /// The rule's clauses, each producing candidate rows per iteration.
    pub clauses: Vec<FixpointClausePlan>,
    /// Schema of the rows this rule yields.
    pub yield_schema: SchemaRef,
    /// Indices of the grouping-key columns within `yield_schema`.
    pub key_column_indices: Vec<usize>,
    /// Optional rule-level priority.
    pub priority: Option<i64>,
    /// Whether a FOLD post-processing step applies.
    pub has_fold: bool,
    /// FOLD bindings (empty when `has_fold` is false).
    pub fold_bindings: Vec<FoldBinding>,
    /// Whether a BEST-BY post-processing step applies.
    pub has_best_by: bool,
    /// BEST-BY sort criteria (empty when `has_best_by` is false).
    pub best_by_criteria: Vec<SortCriterion>,
    /// Whether a PRIORITY step applies (requires a `__priority` column).
    pub has_priority: bool,
    /// Whether BEST-BY tie-breaking must be deterministic.
    pub deterministic: bool,
}
713
/// Runs the semi-naive fixpoint loop over all `rules` until no rule derives
/// a new fact and all aggregates are stable (convergence), erroring when
/// `max_iterations` or `timeout` is exhausted first. On success, returns the
/// post-processed (PRIORITY/FOLD/BEST-BY) output batches of every rule.
#[allow(clippy::too_many_arguments)]
async fn run_fixpoint_loop(
    rules: Vec<FixpointRulePlan>,
    max_iterations: usize,
    timeout: Duration,
    graph_ctx: Arc<GraphExecutionContext>,
    session_ctx: Arc<RwLock<datafusion::prelude::SessionContext>>,
    storage: Arc<StorageManager>,
    schema_info: Arc<UniSchema>,
    params: HashMap<String, Value>,
    registry: Arc<DerivedScanRegistry>,
    output_schema: SchemaRef,
    max_derived_bytes: usize,
    derivation_tracker: Option<Arc<DerivationTracker>>,
    iteration_counts: Arc<StdRwLock<HashMap<String, usize>>>,
) -> DFResult<Vec<RecordBatch>> {
    let start = Instant::now();
    let task_ctx = session_ctx.read().task_ctx();

    // One FixpointState per rule; rules with FOLD bindings also get a
    // MonotonicAggState so convergence requires stable aggregates.
    let mut states: Vec<FixpointState> = rules
        .iter()
        .map(|rule| {
            let monotonic_agg = if !rule.fold_bindings.is_empty() {
                let bindings: Vec<MonotonicFoldBinding> = rule
                    .fold_bindings
                    .iter()
                    .map(|fb| MonotonicFoldBinding {
                        fold_name: fb.output_name.clone(),
                        kind: fb.kind.clone(),
                        input_col_index: fb.input_col_index,
                    })
                    .collect();
                Some(MonotonicAggState::new(bindings))
            } else {
                None
            };
            FixpointState::new(
                rule.name.clone(),
                Arc::clone(&rule.yield_schema),
                rule.key_column_indices.clone(),
                max_derived_bytes,
                monotonic_agg,
            )
        })
        .collect();

    let mut converged = false;
    let mut total_iters = 0usize;
    for iteration in 0..max_iterations {
        total_iters = iteration + 1;
        tracing::debug!("fixpoint iteration {}", iteration);
        let mut any_changed = false;

        for rule_idx in 0..rules.len() {
            let rule = &rules[rule_idx];

            // Publish current facts/deltas into the shared derived-scan
            // handles BEFORE executing this rule's bodies, so its scans read
            // up-to-date inputs (delta-only for self-references).
            update_derived_scan_handles(&registry, &states, rule_idx, &rules);

            let mut all_candidates = Vec::new();
            let mut clause_candidates: Vec<Vec<RecordBatch>> = Vec::new();
            for clause in &rule.clauses {
                let mut batches = execute_subplan(
                    &clause.body_logical,
                    &params,
                    &HashMap::new(),
                    &graph_ctx,
                    &session_ctx,
                    &storage,
                    &schema_info,
                )
                .await?;
                // Negated IS-references: strip candidate rows whose join
                // keys appear in the referenced rule's current facts.
                for binding in &clause.is_ref_bindings {
                    if binding.negated
                        && !binding.anti_join_cols.is_empty()
                        && let Some(entry) = registry.get(binding.derived_scan_index)
                    {
                        let neg_facts = entry.data.read().clone();
                        if !neg_facts.is_empty() {
                            for (left_col, right_col) in &binding.anti_join_cols {
                                batches =
                                    apply_anti_join(batches, &neg_facts, left_col, right_col)?;
                            }
                        }
                    }
                }
                clause_candidates.push(batches.clone());
                all_candidates.extend(batches);
            }

            let changed = states[rule_idx]
                .merge_delta(all_candidates, Some(Arc::clone(&task_ctx)))
                .await?;
            if changed {
                any_changed = true;
                // Provenance is recorded only for rows that survived dedup.
                if let Some(ref tracker) = derivation_tracker {
                    record_provenance(
                        tracker,
                        rule,
                        &states[rule_idx],
                        &clause_candidates,
                        iteration,
                    );
                }
            }
        }

        if !any_changed && states.iter().all(|s| s.is_converged()) {
            tracing::debug!("fixpoint converged after {} iterations", iteration + 1);
            converged = true;
            break;
        }

        // Timeout is checked once per iteration (after the work), so a
        // single long iteration can overshoot `timeout`.
        if start.elapsed() > timeout {
            return Err(datafusion::error::DataFusionError::Execution(
                LocyRuntimeError::NonConvergence {
                    iterations: iteration + 1,
                }
                .to_string(),
            ));
        }
    }

    // Expose per-rule iteration counts (e.g. for EXPLAIN); a poisoned lock
    // is silently ignored since this is metrics-only.
    if let Ok(mut counts) = iteration_counts.write() {
        for rule in &rules {
            counts.insert(rule.name.clone(), total_iters);
        }
    }

    if !converged {
        return Err(datafusion::error::DataFusionError::Execution(
            LocyRuntimeError::NonConvergence {
                iterations: max_iterations,
            }
            .to_string(),
        ));
    }

    let task_ctx = session_ctx.read().task_ctx();
    let mut all_output = Vec::new();

    for (rule_idx, state) in states.into_iter().enumerate() {
        let rule = &rules[rule_idx];
        let facts = state.into_facts();
        if facts.is_empty() {
            continue;
        }

        // Apply the rule's PRIORITY/FOLD/BEST-BY chain to its facts.
        let processed = apply_post_fixpoint_chain(facts, rule, &task_ctx).await?;
        all_output.extend(processed);
    }

    // Preserve the output schema even when nothing was derived.
    if all_output.is_empty() {
        all_output.push(RecordBatch::new_empty(output_schema));
    }

    Ok(all_output)
}
892
893fn record_provenance(
902 tracker: &Arc<DerivationTracker>,
903 rule: &FixpointRulePlan,
904 state: &FixpointState,
905 clause_candidates: &[Vec<RecordBatch>],
906 iteration: usize,
907) {
908 let all_indices: Vec<usize> = (0..rule.yield_schema.fields().len()).collect();
909
910 for delta_batch in state.all_delta() {
911 for row_idx in 0..delta_batch.num_rows() {
912 let row_hash = format!(
913 "{:?}",
914 extract_scalar_key(delta_batch, &all_indices, row_idx)
915 )
916 .into_bytes();
917 let fact_row = batch_row_to_value_map(delta_batch, row_idx);
918 let clause_index =
919 find_clause_for_row(delta_batch, row_idx, &all_indices, clause_candidates);
920
921 let entry = DerivationEntry {
922 rule_name: rule.name.clone(),
923 clause_index,
924 inputs: vec![],
925 along_values: std::collections::HashMap::new(),
926 iteration,
927 fact_row,
928 };
929 tracker.record(row_hash, entry);
930 }
931 }
932}
933
934fn find_clause_for_row(
939 delta_batch: &RecordBatch,
940 row_idx: usize,
941 all_indices: &[usize],
942 clause_candidates: &[Vec<RecordBatch>],
943) -> usize {
944 let target_key = extract_scalar_key(delta_batch, all_indices, row_idx);
945 for (clause_idx, batches) in clause_candidates.iter().enumerate() {
946 for batch in batches {
947 if batch.num_columns() != all_indices.len() {
948 continue;
949 }
950 for r in 0..batch.num_rows() {
951 if extract_scalar_key(batch, all_indices, r) == target_key {
952 return clause_idx;
953 }
954 }
955 }
956 }
957 0
958}
959
960fn batch_row_to_value_map(
962 batch: &RecordBatch,
963 row_idx: usize,
964) -> std::collections::HashMap<String, Value> {
965 use uni_store::storage::arrow_convert::arrow_to_value;
966
967 let schema = batch.schema();
968 schema
969 .fields()
970 .iter()
971 .enumerate()
972 .map(|(col_idx, field)| {
973 let col = batch.column(col_idx);
974 let val = arrow_to_value(col.as_ref(), row_idx, None);
975 (field.name().clone(), val)
976 })
977 .collect()
978}
979
980pub fn apply_anti_join(
985 batches: Vec<RecordBatch>,
986 neg_facts: &[RecordBatch],
987 left_col: &str,
988 right_col: &str,
989) -> datafusion::error::Result<Vec<RecordBatch>> {
990 use arrow::compute::filter_record_batch;
991 use arrow_array::{Array as _, BooleanArray, UInt64Array};
992
993 let mut banned: std::collections::HashSet<u64> = std::collections::HashSet::new();
995 for batch in neg_facts {
996 let Ok(idx) = batch.schema().index_of(right_col) else {
997 continue;
998 };
999 let arr = batch.column(idx);
1000 let Some(vids) = arr.as_any().downcast_ref::<UInt64Array>() else {
1001 continue;
1002 };
1003 for i in 0..vids.len() {
1004 if !vids.is_null(i) {
1005 banned.insert(vids.value(i));
1006 }
1007 }
1008 }
1009
1010 if banned.is_empty() {
1011 return Ok(batches);
1012 }
1013
1014 let mut result = Vec::new();
1016 for batch in batches {
1017 let Ok(idx) = batch.schema().index_of(left_col) else {
1018 result.push(batch);
1019 continue;
1020 };
1021 let arr = batch.column(idx);
1022 let Some(vids) = arr.as_any().downcast_ref::<UInt64Array>() else {
1023 result.push(batch);
1024 continue;
1025 };
1026 let keep: Vec<bool> = (0..vids.len())
1027 .map(|i| vids.is_null(i) || !banned.contains(&vids.value(i)))
1028 .collect();
1029 let keep_arr = BooleanArray::from(keep);
1030 let filtered = filter_record_batch(&batch, &keep_arr)
1031 .map_err(|e| datafusion::error::DataFusionError::ArrowError(Box::new(e), None))?;
1032 if filtered.num_rows() > 0 {
1033 result.push(filtered);
1034 }
1035 }
1036 Ok(result)
1037}
1038
1039fn update_derived_scan_handles(
1044 registry: &DerivedScanRegistry,
1045 states: &[FixpointState],
1046 current_rule_idx: usize,
1047 rules: &[FixpointRulePlan],
1048) {
1049 let current_rule_name = &rules[current_rule_idx].name;
1050
1051 for entry in ®istry.entries {
1052 let source_state_idx = rules.iter().position(|r| r.name == entry.rule_name);
1054 let Some(source_idx) = source_state_idx else {
1055 continue;
1056 };
1057
1058 let is_self = entry.rule_name == *current_rule_name;
1059 let data = if is_self {
1060 states[source_idx].all_delta().to_vec()
1062 } else {
1063 states[source_idx].all_facts().to_vec()
1065 };
1066
1067 let data = if data.is_empty() || data.iter().all(|b| b.num_rows() == 0) {
1069 vec![RecordBatch::new_empty(Arc::clone(&entry.schema))]
1070 } else {
1071 data
1072 };
1073
1074 let mut guard = entry.data.write();
1075 *guard = data;
1076 }
1077}
1078
/// Leaf plan node that scans a shared, mutable vector of batches — the
/// mechanism by which fixpoint iterations feed freshly-derived facts back
/// into rule bodies.
pub struct DerivedScanExec {
    // Shared handle rewritten by the fixpoint driver between iterations.
    data: Arc<RwLock<Vec<RecordBatch>>>,
    // Schema of the derived relation.
    schema: SchemaRef,
    // Cached plan properties derived from the schema.
    properties: PlanProperties,
}
1093
1094impl DerivedScanExec {
1095 pub fn new(data: Arc<RwLock<Vec<RecordBatch>>>, schema: SchemaRef) -> Self {
1096 let properties = compute_plan_properties(Arc::clone(&schema));
1097 Self {
1098 data,
1099 schema,
1100 properties,
1101 }
1102 }
1103}
1104
1105impl fmt::Debug for DerivedScanExec {
1106 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1107 f.debug_struct("DerivedScanExec")
1108 .field("schema", &self.schema)
1109 .finish()
1110 }
1111}
1112
1113impl DisplayAs for DerivedScanExec {
1114 fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1115 write!(f, "DerivedScanExec")
1116 }
1117}
1118
1119impl ExecutionPlan for DerivedScanExec {
1120 fn name(&self) -> &str {
1121 "DerivedScanExec"
1122 }
1123 fn as_any(&self) -> &dyn Any {
1124 self
1125 }
1126 fn schema(&self) -> SchemaRef {
1127 Arc::clone(&self.schema)
1128 }
1129 fn properties(&self) -> &PlanProperties {
1130 &self.properties
1131 }
1132 fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
1133 vec![]
1134 }
1135 fn with_new_children(
1136 self: Arc<Self>,
1137 _children: Vec<Arc<dyn ExecutionPlan>>,
1138 ) -> DFResult<Arc<dyn ExecutionPlan>> {
1139 Ok(self)
1140 }
1141 fn execute(
1142 &self,
1143 _partition: usize,
1144 _context: Arc<TaskContext>,
1145 ) -> DFResult<SendableRecordBatchStream> {
1146 let batches = {
1147 let guard = self.data.read();
1148 if guard.is_empty() {
1149 vec![RecordBatch::new_empty(Arc::clone(&self.schema))]
1150 } else {
1151 guard.clone()
1152 }
1153 };
1154 Ok(Box::pin(MemoryStream::try_new(
1155 batches,
1156 Arc::clone(&self.schema),
1157 None,
1158 )?))
1159 }
1160}
1161
/// Minimal leaf plan node that serves a fixed, already-materialized set of
/// batches; used to feed in-memory facts into DataFusion operators.
struct InMemoryExec {
    // The batches to stream, fixed at construction time.
    batches: Vec<RecordBatch>,
    // Schema of `batches`.
    schema: SchemaRef,
    // Cached plan properties derived from the schema.
    properties: PlanProperties,
}
1175
1176impl InMemoryExec {
1177 fn new(batches: Vec<RecordBatch>, schema: SchemaRef) -> Self {
1178 let properties = compute_plan_properties(Arc::clone(&schema));
1179 Self {
1180 batches,
1181 schema,
1182 properties,
1183 }
1184 }
1185}
1186
1187impl fmt::Debug for InMemoryExec {
1188 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1189 f.debug_struct("InMemoryExec")
1190 .field("num_batches", &self.batches.len())
1191 .field("schema", &self.schema)
1192 .finish()
1193 }
1194}
1195
1196impl DisplayAs for InMemoryExec {
1197 fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1198 write!(f, "InMemoryExec: batches={}", self.batches.len())
1199 }
1200}
1201
1202impl ExecutionPlan for InMemoryExec {
1203 fn name(&self) -> &str {
1204 "InMemoryExec"
1205 }
1206 fn as_any(&self) -> &dyn Any {
1207 self
1208 }
1209 fn schema(&self) -> SchemaRef {
1210 Arc::clone(&self.schema)
1211 }
1212 fn properties(&self) -> &PlanProperties {
1213 &self.properties
1214 }
1215 fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
1216 vec![]
1217 }
1218 fn with_new_children(
1219 self: Arc<Self>,
1220 _children: Vec<Arc<dyn ExecutionPlan>>,
1221 ) -> DFResult<Arc<dyn ExecutionPlan>> {
1222 Ok(self)
1223 }
1224 fn execute(
1225 &self,
1226 _partition: usize,
1227 _context: Arc<TaskContext>,
1228 ) -> DFResult<SendableRecordBatchStream> {
1229 Ok(Box::pin(MemoryStream::try_new(
1230 self.batches.clone(),
1231 Arc::clone(&self.schema),
1232 None,
1233 )?))
1234 }
1235}
1236
1237pub(crate) async fn apply_post_fixpoint_chain(
1243 facts: Vec<RecordBatch>,
1244 rule: &FixpointRulePlan,
1245 task_ctx: &Arc<TaskContext>,
1246) -> DFResult<Vec<RecordBatch>> {
1247 if !rule.has_fold && !rule.has_best_by && !rule.has_priority {
1248 return Ok(facts);
1249 }
1250
1251 let schema = Arc::clone(&rule.yield_schema);
1253 let input: Arc<dyn ExecutionPlan> = Arc::new(InMemoryExec::new(facts, schema));
1254
1255 let current: Arc<dyn ExecutionPlan> = if rule.has_priority {
1259 let priority_schema = input.schema();
1260 let priority_idx = priority_schema.index_of("__priority").map_err(|_| {
1261 datafusion::common::DataFusionError::Internal(
1262 "PRIORITY rule missing __priority column".to_string(),
1263 )
1264 })?;
1265 Arc::new(PriorityExec::new(
1266 input,
1267 rule.key_column_indices.clone(),
1268 priority_idx,
1269 ))
1270 } else {
1271 input
1272 };
1273
1274 let current: Arc<dyn ExecutionPlan> = if rule.has_fold && !rule.fold_bindings.is_empty() {
1276 Arc::new(FoldExec::new(
1277 current,
1278 rule.key_column_indices.clone(),
1279 rule.fold_bindings.clone(),
1280 ))
1281 } else {
1282 current
1283 };
1284
1285 let current: Arc<dyn ExecutionPlan> = if rule.has_best_by && !rule.best_by_criteria.is_empty() {
1287 Arc::new(BestByExec::new(
1288 current,
1289 rule.key_column_indices.clone(),
1290 rule.best_by_criteria.clone(),
1291 rule.deterministic,
1292 ))
1293 } else {
1294 current
1295 };
1296
1297 collect_all_partitions(¤t, Arc::clone(task_ctx)).await
1298}
1299
/// Physical operator that runs the whole fixpoint computation
/// (`run_fixpoint_loop`) when executed.
pub struct FixpointExec {
    /// The planned rules to iterate to a fixpoint.
    rules: Vec<FixpointRulePlan>,
    /// Hard cap on fixpoint iterations before reporting non-convergence.
    max_iterations: usize,
    /// Wall-clock budget for the loop.
    timeout: Duration,
    /// Graph execution context passed to rule-body subplans.
    graph_ctx: Arc<GraphExecutionContext>,
    /// Session used to build task contexts and execute subplans.
    session_ctx: Arc<RwLock<datafusion::prelude::SessionContext>>,
    /// Storage backend for base-relation scans.
    storage: Arc<StorageManager>,
    /// Logical schema metadata for subplan execution.
    schema_info: Arc<UniSchema>,
    /// Query parameters forwarded to rule bodies.
    params: HashMap<String, Value>,
    /// Shared derived-scan handles refreshed each iteration.
    derived_scan_registry: Arc<DerivedScanRegistry>,
    /// Schema of the operator's output.
    output_schema: SchemaRef,
    /// Cached plan properties derived from `output_schema`.
    properties: PlanProperties,
    /// DataFusion metrics registry for this node.
    metrics: ExecutionPlanMetricsSet,
    /// Memory cap on derived facts, per rule state.
    max_derived_bytes: usize,
    /// Optional provenance recorder (EXPLAIN DERIVATION support).
    derivation_tracker: Option<Arc<DerivationTracker>>,
    /// Per-rule iteration counts, populated after the loop finishes.
    iteration_counts: Arc<StdRwLock<HashMap<String, usize>>>,
}
1327
1328impl fmt::Debug for FixpointExec {
1329 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1330 f.debug_struct("FixpointExec")
1331 .field("rules_count", &self.rules.len())
1332 .field("max_iterations", &self.max_iterations)
1333 .field("timeout", &self.timeout)
1334 .field("output_schema", &self.output_schema)
1335 .field("max_derived_bytes", &self.max_derived_bytes)
1336 .finish_non_exhaustive()
1337 }
1338}
1339
1340impl FixpointExec {
1341 #[allow(clippy::too_many_arguments)]
1343 pub fn new(
1344 rules: Vec<FixpointRulePlan>,
1345 max_iterations: usize,
1346 timeout: Duration,
1347 graph_ctx: Arc<GraphExecutionContext>,
1348 session_ctx: Arc<RwLock<datafusion::prelude::SessionContext>>,
1349 storage: Arc<StorageManager>,
1350 schema_info: Arc<UniSchema>,
1351 params: HashMap<String, Value>,
1352 derived_scan_registry: Arc<DerivedScanRegistry>,
1353 output_schema: SchemaRef,
1354 max_derived_bytes: usize,
1355 derivation_tracker: Option<Arc<DerivationTracker>>,
1356 iteration_counts: Arc<StdRwLock<HashMap<String, usize>>>,
1357 ) -> Self {
1358 let properties = compute_plan_properties(Arc::clone(&output_schema));
1359 Self {
1360 rules,
1361 max_iterations,
1362 timeout,
1363 graph_ctx,
1364 session_ctx,
1365 storage,
1366 schema_info,
1367 params,
1368 derived_scan_registry,
1369 output_schema,
1370 properties,
1371 metrics: ExecutionPlanMetricsSet::new(),
1372 max_derived_bytes,
1373 derivation_tracker,
1374 iteration_counts,
1375 }
1376 }
1377
1378 pub fn iteration_counts(&self) -> Arc<StdRwLock<HashMap<String, usize>>> {
1380 Arc::clone(&self.iteration_counts)
1381 }
1382}
1383
1384impl DisplayAs for FixpointExec {
1385 fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1386 write!(
1387 f,
1388 "FixpointExec: rules=[{}], max_iter={}, timeout={:?}",
1389 self.rules
1390 .iter()
1391 .map(|r| r.name.as_str())
1392 .collect::<Vec<_>>()
1393 .join(", "),
1394 self.max_iterations,
1395 self.timeout,
1396 )
1397 }
1398}
1399
1400impl ExecutionPlan for FixpointExec {
1401 fn name(&self) -> &str {
1402 "FixpointExec"
1403 }
1404
1405 fn as_any(&self) -> &dyn Any {
1406 self
1407 }
1408
1409 fn schema(&self) -> SchemaRef {
1410 Arc::clone(&self.output_schema)
1411 }
1412
1413 fn properties(&self) -> &PlanProperties {
1414 &self.properties
1415 }
1416
1417 fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
1418 vec![]
1420 }
1421
1422 fn with_new_children(
1423 self: Arc<Self>,
1424 children: Vec<Arc<dyn ExecutionPlan>>,
1425 ) -> DFResult<Arc<dyn ExecutionPlan>> {
1426 if !children.is_empty() {
1427 return Err(datafusion::error::DataFusionError::Plan(
1428 "FixpointExec has no children".to_string(),
1429 ));
1430 }
1431 Ok(self)
1432 }
1433
1434 fn execute(
1435 &self,
1436 partition: usize,
1437 _context: Arc<TaskContext>,
1438 ) -> DFResult<SendableRecordBatchStream> {
1439 let metrics = BaselineMetrics::new(&self.metrics, partition);
1440
1441 let rules = self
1443 .rules
1444 .iter()
1445 .map(|r| {
1446 FixpointRulePlan {
1450 name: r.name.clone(),
1451 clauses: r
1452 .clauses
1453 .iter()
1454 .map(|c| FixpointClausePlan {
1455 body_logical: c.body_logical.clone(),
1456 is_ref_bindings: c.is_ref_bindings.clone(),
1457 priority: c.priority,
1458 })
1459 .collect(),
1460 yield_schema: Arc::clone(&r.yield_schema),
1461 key_column_indices: r.key_column_indices.clone(),
1462 priority: r.priority,
1463 has_fold: r.has_fold,
1464 fold_bindings: r.fold_bindings.clone(),
1465 has_best_by: r.has_best_by,
1466 best_by_criteria: r.best_by_criteria.clone(),
1467 has_priority: r.has_priority,
1468 deterministic: r.deterministic,
1469 }
1470 })
1471 .collect();
1472
1473 let max_iterations = self.max_iterations;
1474 let timeout = self.timeout;
1475 let graph_ctx = Arc::clone(&self.graph_ctx);
1476 let session_ctx = Arc::clone(&self.session_ctx);
1477 let storage = Arc::clone(&self.storage);
1478 let schema_info = Arc::clone(&self.schema_info);
1479 let params = self.params.clone();
1480 let registry = Arc::clone(&self.derived_scan_registry);
1481 let output_schema = Arc::clone(&self.output_schema);
1482 let max_derived_bytes = self.max_derived_bytes;
1483 let derivation_tracker = self.derivation_tracker.clone();
1484 let iteration_counts = Arc::clone(&self.iteration_counts);
1485
1486 let fut = async move {
1487 run_fixpoint_loop(
1488 rules,
1489 max_iterations,
1490 timeout,
1491 graph_ctx,
1492 session_ctx,
1493 storage,
1494 schema_info,
1495 params,
1496 registry,
1497 output_schema,
1498 max_derived_bytes,
1499 derivation_tracker,
1500 iteration_counts,
1501 )
1502 .await
1503 };
1504
1505 Ok(Box::pin(FixpointStream {
1506 state: FixpointStreamState::Running(Box::pin(fut)),
1507 schema: Arc::clone(&self.output_schema),
1508 metrics,
1509 }))
1510 }
1511
1512 fn metrics(&self) -> Option<MetricsSet> {
1513 Some(self.metrics.clone_inner())
1514 }
1515}
1516
/// Internal state machine for [`FixpointStream`].
enum FixpointStreamState {
    /// The fixpoint loop is still running; poll the boxed future until it resolves.
    Running(Pin<Box<dyn std::future::Future<Output = DFResult<Vec<RecordBatch>>> + Send>>),
    /// Loop finished: emit the collected batches one at a time.
    /// The `usize` is the index of the next batch to emit.
    Emitting(Vec<RecordBatch>, usize),
    /// All batches emitted, or an error/empty result terminated the stream.
    Done,
}
1529
/// Stream adapter that drives the fixpoint-loop future and then replays its
/// result batches, recording output-row metrics as it goes.
struct FixpointStream {
    // Current phase: running the loop, emitting results, or done.
    state: FixpointStreamState,
    // Schema reported to consumers via `RecordBatchStream::schema`.
    schema: SchemaRef,
    // Baseline metrics; `record_output` is called per emitted batch.
    metrics: BaselineMetrics,
}
1535
1536impl Stream for FixpointStream {
1537 type Item = DFResult<RecordBatch>;
1538
1539 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
1540 let this = self.get_mut();
1541 loop {
1542 match &mut this.state {
1543 FixpointStreamState::Running(fut) => match fut.as_mut().poll(cx) {
1544 Poll::Ready(Ok(batches)) => {
1545 if batches.is_empty() {
1546 this.state = FixpointStreamState::Done;
1547 return Poll::Ready(None);
1548 }
1549 this.state = FixpointStreamState::Emitting(batches, 0);
1550 }
1552 Poll::Ready(Err(e)) => {
1553 this.state = FixpointStreamState::Done;
1554 return Poll::Ready(Some(Err(e)));
1555 }
1556 Poll::Pending => return Poll::Pending,
1557 },
1558 FixpointStreamState::Emitting(batches, idx) => {
1559 if *idx >= batches.len() {
1560 this.state = FixpointStreamState::Done;
1561 return Poll::Ready(None);
1562 }
1563 let batch = batches[*idx].clone();
1564 *idx += 1;
1565 this.metrics.record_output(batch.num_rows());
1566 return Poll::Ready(Some(Ok(batch)));
1567 }
1568 FixpointStreamState::Done => return Poll::Ready(None),
1569 }
1570 }
1571 }
1572}
1573
impl RecordBatchStream for FixpointStream {
    /// Schema that every batch emitted by this stream conforms to.
    fn schema(&self) -> SchemaRef {
        Arc::clone(&self.schema)
    }
}
1579
#[cfg(test)]
mod tests {
    use super::*;
    use arrow_array::{Float64Array, Int64Array, StringArray};
    use arrow_schema::{DataType, Field, Schema};

    /// Two-column (Utf8 `name`, Int64 `value`) schema shared by most tests.
    fn test_schema() -> SchemaRef {
        Arc::new(Schema::new(vec![
            Field::new("name", DataType::Utf8, true),
            Field::new("value", DataType::Int64, true),
        ]))
    }

    /// Builds a batch over `test_schema()` from parallel name/value slices.
    fn make_batch(names: &[&str], values: &[i64]) -> RecordBatch {
        RecordBatch::try_new(
            test_schema(),
            vec![
                Arc::new(StringArray::from(
                    names.iter().map(|s| Some(*s)).collect::<Vec<_>>(),
                )),
                Arc::new(Int64Array::from(values.to_vec())),
            ],
        )
        .unwrap()
    }

    /// Merging into an empty state adds every row to both facts and delta.
    #[tokio::test]
    async fn test_fixpoint_state_empty_facts_adds_all() {
        let schema = test_schema();
        let mut state = FixpointState::new("test".into(), schema, vec![0], 1_000_000, None);

        let batch = make_batch(&["a", "b", "c"], &[1, 2, 3]);
        let changed = state.merge_delta(vec![batch], None).await.unwrap();

        assert!(changed);
        assert_eq!(state.all_facts().len(), 1);
        assert_eq!(state.all_facts()[0].num_rows(), 3);
        assert_eq!(state.all_delta().len(), 1);
        assert_eq!(state.all_delta()[0].num_rows(), 3);
    }

    /// Re-merging rows that are already present reports no change and an
    /// empty (or all-zero-row) delta.
    #[tokio::test]
    async fn test_fixpoint_state_exact_duplicates_excluded() {
        let schema = test_schema();
        let mut state = FixpointState::new("test".into(), schema, vec![0], 1_000_000, None);

        let batch1 = make_batch(&["a", "b"], &[1, 2]);
        state.merge_delta(vec![batch1], None).await.unwrap();

        let batch2 = make_batch(&["a", "b"], &[1, 2]);
        let changed = state.merge_delta(vec![batch2], None).await.unwrap();
        assert!(!changed);
        assert!(
            state.all_delta().is_empty() || state.all_delta().iter().all(|b| b.num_rows() == 0)
        );
    }

    /// With a partial overlap, only the genuinely new row lands in the delta
    /// while the facts accumulate all distinct rows.
    #[tokio::test]
    async fn test_fixpoint_state_partial_overlap() {
        let schema = test_schema();
        let mut state = FixpointState::new("test".into(), schema, vec![0], 1_000_000, None);

        let batch1 = make_batch(&["a", "b"], &[1, 2]);
        state.merge_delta(vec![batch1], None).await.unwrap();

        // ("a", 1) already exists; ("c", 3) is new.
        let batch2 = make_batch(&["a", "c"], &[1, 3]);
        let changed = state.merge_delta(vec![batch2], None).await.unwrap();
        assert!(changed);

        let delta_rows: usize = state.all_delta().iter().map(|b| b.num_rows()).sum();
        assert_eq!(delta_rows, 1);

        let total_rows: usize = state.all_facts().iter().map(|b| b.num_rows()).sum();
        assert_eq!(total_rows, 3);
    }

    /// An empty merge after a non-empty one marks the state converged.
    #[tokio::test]
    async fn test_fixpoint_state_convergence() {
        let schema = test_schema();
        let mut state = FixpointState::new("test".into(), schema, vec![0], 1_000_000, None);

        let batch = make_batch(&["a"], &[1]);
        state.merge_delta(vec![batch], None).await.unwrap();

        let changed = state.merge_delta(vec![], None).await.unwrap();
        assert!(!changed);
        assert!(state.is_converged());
    }

    /// Rows seen in an earlier `compute_delta` call are still deduplicated
    /// in later calls — the dedup set persists across invocations.
    #[test]
    fn test_row_dedup_persistent_across_calls() {
        let schema = test_schema();
        let mut rd = RowDedupState::try_new(&schema).expect("schema should be supported");

        let batch1 = make_batch(&["a", "b"], &[1, 2]);
        let delta1 = rd.compute_delta(&[batch1], &schema).unwrap();
        let rows1: usize = delta1.iter().map(|b| b.num_rows()).sum();
        assert_eq!(rows1, 2);

        // Identical rows again: nothing new.
        let batch2 = make_batch(&["a", "b"], &[1, 2]);
        let delta2 = rd.compute_delta(&[batch2], &schema).unwrap();
        let rows2: usize = delta2.iter().map(|b| b.num_rows()).sum();
        assert_eq!(rows2, 0);

        // One repeat plus one new row: delta is exactly the new row.
        let batch3 = make_batch(&["a", "c"], &[1, 3]);
        let delta3 = rd.compute_delta(&[batch3], &schema).unwrap();
        let rows3: usize = delta3.iter().map(|b| b.num_rows()).sum();
        assert_eq!(rows3, 1);
    }

    /// NULLs must compare equal to each other for dedup purposes, but a row
    /// differing in a non-NULL column is still distinct.
    #[test]
    fn test_row_dedup_null_handling() {
        use arrow_array::StringArray;
        use arrow_schema::{DataType, Field, Schema};

        let schema: SchemaRef = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Utf8, true),
            Field::new("b", DataType::Int64, true),
        ]));
        let mut rd = RowDedupState::try_new(&schema).expect("schema should be supported");

        let batch_nulls = RecordBatch::try_new(
            Arc::clone(&schema),
            vec![
                Arc::new(StringArray::from(vec![None::<&str>, None::<&str>])),
                Arc::new(arrow_array::Int64Array::from(vec![1i64, 1i64])),
            ],
        )
        .unwrap();
        let delta = rd.compute_delta(&[batch_nulls], &schema).unwrap();
        let rows: usize = delta.iter().map(|b| b.num_rows()).sum();
        assert_eq!(rows, 1, "two identical NULL rows should be deduped to one");

        let batch_diff = RecordBatch::try_new(
            Arc::clone(&schema),
            vec![
                Arc::new(StringArray::from(vec![None::<&str>])),
                Arc::new(arrow_array::Int64Array::from(vec![2i64])),
            ],
        )
        .unwrap();
        let delta2 = rd.compute_delta(&[batch_diff], &schema).unwrap();
        let rows2: usize = delta2.iter().map(|b| b.num_rows()).sum();
        assert_eq!(rows2, 1, "(NULL, 2) is distinct from (NULL, 1)");
    }

    /// Duplicates inside a single candidate batch are collapsed too.
    #[test]
    fn test_row_dedup_within_candidate_dedup() {
        let schema = test_schema();
        let mut rd = RowDedupState::try_new(&schema).expect("schema should be supported");

        let batch = make_batch(&["a", "a", "b"], &[1, 1, 2]);
        let delta = rd.compute_delta(&[batch], &schema).unwrap();
        let rows: usize = delta.iter().map(|b| b.num_rows()).sum();
        assert_eq!(rows, 2, "within-batch dup should be collapsed: a:1, b:2");
    }

    /// Float columns are rounded so that near-duplicate values (differing
    /// only in low-order bits) compare equal after `round_float_columns`.
    #[test]
    fn test_round_float_columns_near_duplicates() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("name", DataType::Utf8, true),
            Field::new("dist", DataType::Float64, true),
        ]));
        let batch = RecordBatch::try_new(
            Arc::clone(&schema),
            vec![
                Arc::new(StringArray::from(vec![Some("a"), Some("a")])),
                Arc::new(Float64Array::from(vec![1.0000000000001, 1.0000000000002])),
            ],
        )
        .unwrap();

        let rounded = round_float_columns(&[batch]);
        assert_eq!(rounded.len(), 1);
        let col = rounded[0]
            .column(1)
            .as_any()
            .downcast_ref::<Float64Array>()
            .unwrap();
        assert_eq!(col.value(0), col.value(1));
    }

    /// Data written through the registry is visible via the shared
    /// `Arc<RwLock<_>>` handle held by the entry.
    #[test]
    fn test_registry_write_read_round_trip() {
        let schema = test_schema();
        let data = Arc::new(RwLock::new(Vec::new()));
        let mut reg = DerivedScanRegistry::new();
        reg.add(DerivedScanEntry {
            scan_index: 0,
            rule_name: "reachable".into(),
            is_self_ref: true,
            data: Arc::clone(&data),
            schema: Arc::clone(&schema),
        });

        let batch = make_batch(&["x"], &[42]);
        reg.write_data(0, vec![batch.clone()]);

        let entry = reg.get(0).unwrap();
        let guard = entry.data.read();
        assert_eq!(guard.len(), 1);
        assert_eq!(guard[0].num_rows(), 1);
    }

    /// `entries_for_rule` filters by rule name; unknown names yield nothing.
    #[test]
    fn test_registry_entries_for_rule() {
        let schema = test_schema();
        let mut reg = DerivedScanRegistry::new();
        reg.add(DerivedScanEntry {
            scan_index: 0,
            rule_name: "r1".into(),
            is_self_ref: true,
            data: Arc::new(RwLock::new(Vec::new())),
            schema: Arc::clone(&schema),
        });
        reg.add(DerivedScanEntry {
            scan_index: 1,
            rule_name: "r2".into(),
            is_self_ref: false,
            data: Arc::new(RwLock::new(Vec::new())),
            schema: Arc::clone(&schema),
        });
        reg.add(DerivedScanEntry {
            scan_index: 2,
            rule_name: "r1".into(),
            is_self_ref: false,
            data: Arc::new(RwLock::new(Vec::new())),
            schema: Arc::clone(&schema),
        });

        assert_eq!(reg.entries_for_rule("r1").len(), 2);
        assert_eq!(reg.entries_for_rule("r2").len(), 1);
        assert_eq!(reg.entries_for_rule("r3").len(), 0);
    }

    /// New input changes the aggregate (not stable); a subsequent round with
    /// no input leaves it unchanged (stable).
    #[test]
    fn test_monotonic_agg_update_and_stability() {
        use crate::query::df_graph::locy_fold::FoldAggKind;

        let bindings = vec![MonotonicFoldBinding {
            fold_name: "total".into(),
            kind: FoldAggKind::Sum,
            input_col_index: 1,
        }];
        let mut agg = MonotonicAggState::new(bindings);

        let batch = make_batch(&["a"], &[10]);
        agg.snapshot();
        let changed = agg.update(&[0], &[batch]);
        assert!(changed);
        assert!(!agg.is_stable());

        // Second snapshot/update round with no new rows: aggregates settle.
        agg.snapshot();
        let changed = agg.update(&[0], &[]);
        assert!(!changed);
        assert!(agg.is_stable());
    }

    /// A 1-byte budget is exceeded immediately and surfaces a
    /// memory-limit error rather than silently dropping data.
    #[tokio::test]
    async fn test_memory_limit_exceeded() {
        let schema = test_schema();
        let mut state = FixpointState::new("test".into(), schema, vec![0], 1, None);

        let batch = make_batch(&["a", "b", "c"], &[1, 2, 3]);
        let result = state.merge_delta(vec![batch], None).await;
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(err.contains("memory limit"), "Error was: {}", err);
    }

    /// A stream seeded in the `Emitting` state replays its batches in order
    /// and then terminates.
    #[tokio::test]
    async fn test_fixpoint_stream_emitting() {
        use futures::StreamExt;

        let schema = test_schema();
        let batch1 = make_batch(&["a"], &[1]);
        let batch2 = make_batch(&["b"], &[2]);

        let metrics = ExecutionPlanMetricsSet::new();
        let baseline = BaselineMetrics::new(&metrics, 0);

        let mut stream = FixpointStream {
            state: FixpointStreamState::Emitting(vec![batch1, batch2], 0),
            schema,
            metrics: baseline,
        };

        let stream = Pin::new(&mut stream);
        let batches: Vec<RecordBatch> = stream.filter_map(|r| async { r.ok() }).collect().await;

        assert_eq!(batches.len(), 2);
        assert_eq!(batches[0].num_rows(), 1);
        assert_eq!(batches[1].num_rows(), 1);
    }
}