1use crate::query::df_graph::GraphExecutionContext;
12use crate::query::df_graph::common::{
13 collect_all_partitions, compute_plan_properties, execute_subplan,
14};
15use crate::query::df_graph::locy_best_by::SortCriterion;
16use crate::query::df_graph::locy_explain::DerivationTracker;
17use crate::query::df_graph::locy_fixpoint::{
18 DerivedScanRegistry, FixpointClausePlan, FixpointExec, FixpointRulePlan, IsRefBinding,
19};
20use crate::query::df_graph::locy_fold::{FoldAggKind, FoldBinding};
21use crate::query::planner_locy_types::{
22 LocyCommand, LocyIsRef, LocyRulePlan, LocyStratum, LocyYieldColumn,
23};
24use arrow_array::RecordBatch;
25use arrow_schema::{DataType, Field, Schema as ArrowSchema, SchemaRef};
26use datafusion::common::Result as DFResult;
27use datafusion::execution::{RecordBatchStream, SendableRecordBatchStream, TaskContext};
28use datafusion::physical_plan::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet};
29use datafusion::physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties};
30use futures::Stream;
31use parking_lot::RwLock;
32use std::any::Any;
33use std::collections::HashMap;
34use std::fmt;
35use std::pin::Pin;
36use std::sync::Arc;
37use std::sync::RwLock as StdRwLock;
38use std::task::{Context, Poll};
39use std::time::{Duration, Instant};
40use uni_common::Value;
41use uni_common::core::schema::Schema as UniSchema;
42use uni_cypher::ast::Expr;
43use uni_store::storage::manager::StorageManager;
44
/// In-memory store of fully evaluated derived relations, keyed by rule name.
/// Populated by `run_program` and published via `LocyProgramExec::derived_store_slot`.
pub struct DerivedStore {
    // rule name -> materialized fact batches for that rule
    relations: HashMap<String, Vec<RecordBatch>>,
}
56
57impl Default for DerivedStore {
58 fn default() -> Self {
59 Self::new()
60 }
61}
62
63impl DerivedStore {
64 pub fn new() -> Self {
65 Self {
66 relations: HashMap::new(),
67 }
68 }
69
70 pub fn insert(&mut self, rule_name: String, facts: Vec<RecordBatch>) {
71 self.relations.insert(rule_name, facts);
72 }
73
74 pub fn get(&self, rule_name: &str) -> Option<&Vec<RecordBatch>> {
75 self.relations.get(rule_name)
76 }
77
78 pub fn fact_count(&self, rule_name: &str) -> usize {
79 self.relations
80 .get(rule_name)
81 .map(|batches| batches.iter().map(|b| b.num_rows()).sum())
82 .unwrap_or(0)
83 }
84
85 pub fn rule_names(&self) -> impl Iterator<Item = &str> {
86 self.relations.keys().map(|s| s.as_str())
87 }
88}
89
/// Physical operator that evaluates an entire Locy program (all strata) and
/// emits a single summary batch (see `stats_schema`). Detailed results are
/// published out-of-band through the `*_slot` fields.
pub struct LocyProgramExec {
    // Strata to evaluate, in dependency order.
    strata: Vec<LocyStratum>,
    // Trailing commands; currently only counted in Debug/Display output here.
    commands: Vec<LocyCommand>,
    // Registry backing derived-relation scans embedded in the sub-plans.
    derived_scan_registry: Arc<DerivedScanRegistry>,
    graph_ctx: Arc<GraphExecutionContext>,
    session_ctx: Arc<RwLock<datafusion::prelude::SessionContext>>,
    storage: Arc<StorageManager>,
    schema_info: Arc<UniSchema>,
    // Query parameters forwarded to sub-plan execution.
    params: HashMap<String, Value>,
    output_schema: SchemaRef,
    properties: PlanProperties,
    metrics: ExecutionPlanMetricsSet,
    // Safety limits for fixpoint evaluation.
    max_iterations: usize,
    timeout: Duration,
    max_derived_bytes: usize,
    deterministic_best_by: bool,
    // Filled after execution with the final derived facts.
    derived_store_slot: Arc<StdRwLock<Option<DerivedStore>>>,
    // Optional explain-support tracker, installed before execution.
    derivation_tracker: Arc<StdRwLock<Option<Arc<DerivationTracker>>>>,
    // Filled by FixpointExec with per-rule iteration counts.
    iteration_counts_slot: Arc<StdRwLock<HashMap<String, usize>>>,
    // Filled after execution with peak derived-fact buffer bytes.
    peak_memory_slot: Arc<StdRwLock<usize>>,
}
124
125impl fmt::Debug for LocyProgramExec {
126 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
127 f.debug_struct("LocyProgramExec")
128 .field("strata_count", &self.strata.len())
129 .field("commands_count", &self.commands.len())
130 .field("max_iterations", &self.max_iterations)
131 .field("timeout", &self.timeout)
132 .field("output_schema", &self.output_schema)
133 .field("max_derived_bytes", &self.max_derived_bytes)
134 .finish_non_exhaustive()
135 }
136}
137
138impl LocyProgramExec {
139 #[allow(clippy::too_many_arguments)]
140 pub fn new(
141 strata: Vec<LocyStratum>,
142 commands: Vec<LocyCommand>,
143 derived_scan_registry: Arc<DerivedScanRegistry>,
144 graph_ctx: Arc<GraphExecutionContext>,
145 session_ctx: Arc<RwLock<datafusion::prelude::SessionContext>>,
146 storage: Arc<StorageManager>,
147 schema_info: Arc<UniSchema>,
148 params: HashMap<String, Value>,
149 output_schema: SchemaRef,
150 max_iterations: usize,
151 timeout: Duration,
152 max_derived_bytes: usize,
153 deterministic_best_by: bool,
154 ) -> Self {
155 let properties = compute_plan_properties(Arc::clone(&output_schema));
156 Self {
157 strata,
158 commands,
159 derived_scan_registry,
160 graph_ctx,
161 session_ctx,
162 storage,
163 schema_info,
164 params,
165 output_schema,
166 properties,
167 metrics: ExecutionPlanMetricsSet::new(),
168 max_iterations,
169 timeout,
170 max_derived_bytes,
171 deterministic_best_by,
172 derived_store_slot: Arc::new(StdRwLock::new(None)),
173 derivation_tracker: Arc::new(StdRwLock::new(None)),
174 iteration_counts_slot: Arc::new(StdRwLock::new(HashMap::new())),
175 peak_memory_slot: Arc::new(StdRwLock::new(0)),
176 }
177 }
178
179 pub fn derived_store_slot(&self) -> Arc<StdRwLock<Option<DerivedStore>>> {
184 Arc::clone(&self.derived_store_slot)
185 }
186
187 pub fn set_derivation_tracker(&self, tracker: Arc<DerivationTracker>) {
192 if let Ok(mut guard) = self.derivation_tracker.write() {
193 *guard = Some(tracker);
194 }
195 }
196
197 pub fn iteration_counts_slot(&self) -> Arc<StdRwLock<HashMap<String, usize>>> {
202 Arc::clone(&self.iteration_counts_slot)
203 }
204
205 pub fn peak_memory_slot(&self) -> Arc<StdRwLock<usize>> {
210 Arc::clone(&self.peak_memory_slot)
211 }
212}
213
214impl DisplayAs for LocyProgramExec {
215 fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter<'_>) -> fmt::Result {
216 write!(
217 f,
218 "LocyProgramExec: strata={}, commands={}, max_iter={}, timeout={:?}",
219 self.strata.len(),
220 self.commands.len(),
221 self.max_iterations,
222 self.timeout,
223 )
224 }
225}
226
impl ExecutionPlan for LocyProgramExec {
    fn name(&self) -> &str {
        "LocyProgramExec"
    }

    fn as_any(&self) -> &dyn Any {
        self
    }

    fn schema(&self) -> SchemaRef {
        Arc::clone(&self.output_schema)
    }

    fn properties(&self) -> &PlanProperties {
        &self.properties
    }

    // Leaf node: sub-plans are owned by the strata, not exposed as children.
    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
        vec![]
    }

    fn with_new_children(
        self: Arc<Self>,
        children: Vec<Arc<dyn ExecutionPlan>>,
    ) -> DFResult<Arc<dyn ExecutionPlan>> {
        // Being a leaf, only an empty child list is accepted.
        if !children.is_empty() {
            return Err(datafusion::error::DataFusionError::Plan(
                "LocyProgramExec has no children".to_string(),
            ));
        }
        Ok(self)
    }

    /// Kicks off whole-program evaluation lazily: everything the async
    /// `run_program` needs is cloned/Arc-cloned here so the future is
    /// `'static`, then wrapped in a `ProgramStream` that drives it on poll.
    fn execute(
        &self,
        partition: usize,
        _context: Arc<TaskContext>,
    ) -> DFResult<SendableRecordBatchStream> {
        let metrics = BaselineMetrics::new(&self.metrics, partition);

        let strata = self.strata.clone();
        let registry = Arc::clone(&self.derived_scan_registry);
        let graph_ctx = Arc::clone(&self.graph_ctx);
        let session_ctx = Arc::clone(&self.session_ctx);
        let storage = Arc::clone(&self.storage);
        let schema_info = Arc::clone(&self.schema_info);
        let params = self.params.clone();
        let output_schema = Arc::clone(&self.output_schema);
        let max_iterations = self.max_iterations;
        let timeout = self.timeout;
        let max_derived_bytes = self.max_derived_bytes;
        let deterministic_best_by = self.deterministic_best_by;
        let derived_store_slot = Arc::clone(&self.derived_store_slot);
        let iteration_counts_slot = Arc::clone(&self.iteration_counts_slot);
        let peak_memory_slot = Arc::clone(&self.peak_memory_slot);
        // Snapshot the optional tracker; a poisoned lock yields None.
        let derivation_tracker = self.derivation_tracker.read().ok().and_then(|g| g.clone());

        let fut = async move {
            run_program(
                strata,
                registry,
                graph_ctx,
                session_ctx,
                storage,
                schema_info,
                params,
                output_schema,
                max_iterations,
                timeout,
                max_derived_bytes,
                deterministic_best_by,
                derived_store_slot,
                iteration_counts_slot,
                peak_memory_slot,
                derivation_tracker,
            )
            .await
        };

        Ok(Box::pin(ProgramStream {
            state: ProgramStreamState::Running(Box::pin(fut)),
            schema: Arc::clone(&self.output_schema),
            metrics,
        }))
    }

    fn metrics(&self) -> Option<MetricsSet> {
        Some(self.metrics.clone_inner())
    }
}
317
/// State machine for `ProgramStream`: run the program future to completion,
/// then emit the collected batches one at a time, then finish.
enum ProgramStreamState {
    // Program still executing; holds the boxed `run_program` future.
    Running(Pin<Box<dyn std::future::Future<Output = DFResult<Vec<RecordBatch>>> + Send>>),
    // Program finished; batches plus the index of the next batch to emit.
    Emitting(Vec<RecordBatch>, usize),
    // All batches emitted (or an error was returned).
    Done,
}
327
/// `RecordBatchStream` adapter that drives the program future and streams
/// the resulting batches, recording output-row metrics as it goes.
struct ProgramStream {
    state: ProgramStreamState,
    schema: SchemaRef,
    metrics: BaselineMetrics,
}
333
334impl Stream for ProgramStream {
335 type Item = DFResult<RecordBatch>;
336
337 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
338 let this = self.get_mut();
339 loop {
340 match &mut this.state {
341 ProgramStreamState::Running(fut) => match fut.as_mut().poll(cx) {
342 Poll::Ready(Ok(batches)) => {
343 if batches.is_empty() {
344 this.state = ProgramStreamState::Done;
345 return Poll::Ready(None);
346 }
347 this.state = ProgramStreamState::Emitting(batches, 0);
348 }
349 Poll::Ready(Err(e)) => {
350 this.state = ProgramStreamState::Done;
351 return Poll::Ready(Some(Err(e)));
352 }
353 Poll::Pending => return Poll::Pending,
354 },
355 ProgramStreamState::Emitting(batches, idx) => {
356 if *idx >= batches.len() {
357 this.state = ProgramStreamState::Done;
358 return Poll::Ready(None);
359 }
360 let batch = batches[*idx].clone();
361 *idx += 1;
362 this.metrics.record_output(batch.num_rows());
363 return Poll::Ready(Some(Ok(batch)));
364 }
365 ProgramStreamState::Done => return Poll::Ready(None),
366 }
367 }
368 }
369}
370
371impl RecordBatchStream for ProgramStream {
372 fn schema(&self) -> SchemaRef {
373 Arc::clone(&self.schema)
374 }
375}
376
377#[allow(clippy::too_many_arguments)]
382async fn run_program(
383 strata: Vec<LocyStratum>,
384 registry: Arc<DerivedScanRegistry>,
385 graph_ctx: Arc<GraphExecutionContext>,
386 session_ctx: Arc<RwLock<datafusion::prelude::SessionContext>>,
387 storage: Arc<StorageManager>,
388 schema_info: Arc<UniSchema>,
389 params: HashMap<String, Value>,
390 output_schema: SchemaRef,
391 max_iterations: usize,
392 timeout: Duration,
393 max_derived_bytes: usize,
394 deterministic_best_by: bool,
395 derived_store_slot: Arc<StdRwLock<Option<DerivedStore>>>,
396 iteration_counts_slot: Arc<StdRwLock<HashMap<String, usize>>>,
397 peak_memory_slot: Arc<StdRwLock<usize>>,
398 derivation_tracker: Option<Arc<DerivationTracker>>,
399) -> DFResult<Vec<RecordBatch>> {
400 let start = Instant::now();
401 let mut derived_store = DerivedStore::new();
402
403 for stratum in &strata {
405 write_cross_stratum_facts(®istry, &derived_store, stratum);
407
408 let remaining_timeout = timeout.saturating_sub(start.elapsed());
409 if remaining_timeout.is_zero() {
410 return Err(datafusion::error::DataFusionError::Execution(
411 "Locy program timeout exceeded during stratum evaluation".to_string(),
412 ));
413 }
414
415 if stratum.is_recursive {
416 let fixpoint_rules =
418 convert_to_fixpoint_plans(&stratum.rules, ®istry, deterministic_best_by)?;
419 let fixpoint_schema = build_fixpoint_output_schema(&stratum.rules);
420
421 let exec = FixpointExec::new(
422 fixpoint_rules,
423 max_iterations,
424 remaining_timeout,
425 Arc::clone(&graph_ctx),
426 Arc::clone(&session_ctx),
427 Arc::clone(&storage),
428 Arc::clone(&schema_info),
429 params.clone(),
430 Arc::clone(®istry),
431 fixpoint_schema,
432 max_derived_bytes,
433 derivation_tracker.clone(),
434 Arc::clone(&iteration_counts_slot),
435 );
436
437 let task_ctx = session_ctx.read().task_ctx();
438 let exec_arc: Arc<dyn ExecutionPlan> = Arc::new(exec);
439 let batches = collect_all_partitions(&exec_arc, task_ctx).await?;
440
441 for rule in &stratum.rules {
447 let rule_entries = registry.entries_for_rule(&rule.name);
449 for entry in rule_entries {
450 if !entry.is_self_ref {
451 let all_facts: Vec<RecordBatch> = batches
455 .iter()
456 .filter(|b| {
457 let rule_schema = yield_columns_to_arrow_schema(&rule.yield_schema);
459 b.schema().fields().len() == rule_schema.fields().len()
460 })
461 .cloned()
462 .collect();
463 let mut guard = entry.data.write();
464 *guard = if all_facts.is_empty() {
465 vec![RecordBatch::new_empty(Arc::clone(&entry.schema))]
466 } else {
467 all_facts
468 };
469 }
470 }
471 derived_store.insert(rule.name.clone(), batches.clone());
472 }
473 } else {
474 let fixpoint_rules =
476 convert_to_fixpoint_plans(&stratum.rules, ®istry, deterministic_best_by)?;
477 let task_ctx = session_ctx.read().task_ctx();
478
479 for (rule, fp_rule) in stratum.rules.iter().zip(fixpoint_rules.iter()) {
480 let mut facts = evaluate_non_recursive_rule(
481 rule,
482 ¶ms,
483 &graph_ctx,
484 &session_ctx,
485 &storage,
486 &schema_info,
487 )
488 .await?;
489
490 for clause in &fp_rule.clauses {
494 for binding in &clause.is_ref_bindings {
495 if binding.negated
496 && !binding.anti_join_cols.is_empty()
497 && let Some(entry) = registry.get(binding.derived_scan_index)
498 {
499 let neg_facts = entry.data.read().clone();
500 if !neg_facts.is_empty() {
501 for (left_col, right_col) in &binding.anti_join_cols {
502 facts = super::locy_fixpoint::apply_anti_join(
503 facts, &neg_facts, left_col, right_col,
504 )?;
505 }
506 }
507 }
508 }
509 }
510
511 let facts =
513 super::locy_fixpoint::apply_post_fixpoint_chain(facts, fp_rule, &task_ctx)
514 .await?;
515
516 write_facts_to_registry(®istry, &rule.name, &facts);
518 derived_store.insert(rule.name.clone(), facts);
519 }
520 }
521 }
522
523 let peak_bytes: usize = derived_store
525 .relations
526 .values()
527 .flat_map(|batches| batches.iter())
528 .map(|b| {
529 b.columns()
530 .iter()
531 .map(|col| col.get_buffer_memory_size())
532 .sum::<usize>()
533 })
534 .sum();
535 *peak_memory_slot.write().unwrap() = peak_bytes;
536
537 let stats = vec![build_stats_batch(&derived_store, &strata, output_schema)];
541 *derived_store_slot.write().unwrap() = Some(derived_store);
542 Ok(stats)
543}
544
545async fn evaluate_non_recursive_rule(
550 rule: &LocyRulePlan,
551 params: &HashMap<String, Value>,
552 graph_ctx: &Arc<GraphExecutionContext>,
553 session_ctx: &Arc<RwLock<datafusion::prelude::SessionContext>>,
554 storage: &Arc<StorageManager>,
555 schema_info: &Arc<UniSchema>,
556) -> DFResult<Vec<RecordBatch>> {
557 let mut all_batches = Vec::new();
558
559 for clause in &rule.clauses {
560 let batches = execute_subplan(
561 &clause.body,
562 params,
563 &HashMap::new(),
564 graph_ctx,
565 session_ctx,
566 storage,
567 schema_info,
568 )
569 .await?;
570 all_batches.extend(batches);
571 }
572
573 Ok(all_batches)
574}
575
576fn write_cross_stratum_facts(
582 registry: &DerivedScanRegistry,
583 derived_store: &DerivedStore,
584 stratum: &LocyStratum,
585) {
586 for rule in &stratum.rules {
588 for clause in &rule.clauses {
589 for is_ref in &clause.is_refs {
590 if let Some(facts) = derived_store.get(&is_ref.rule_name) {
593 write_facts_to_registry(registry, &is_ref.rule_name, facts);
594 }
595 }
596 }
597 }
598}
599
600fn write_facts_to_registry(registry: &DerivedScanRegistry, rule_name: &str, facts: &[RecordBatch]) {
602 let entries = registry.entries_for_rule(rule_name);
603 for entry in entries {
604 if !entry.is_self_ref {
605 let mut guard = entry.data.write();
606 *guard = if facts.is_empty() || facts.iter().all(|b| b.num_rows() == 0) {
607 vec![RecordBatch::new_empty(Arc::clone(&entry.schema))]
608 } else {
609 facts
613 .iter()
614 .filter_map(|b| {
615 RecordBatch::try_new(Arc::clone(&entry.schema), b.columns().to_vec()).ok()
616 })
617 .collect()
618 };
619 }
620 }
621}
622
/// Lowers planner-level rule plans into the executable `FixpointRulePlan`
/// form: resolves IS-refs against the registry, converts FOLD / BEST BY
/// bindings to positional form, and (when the rule has a priority) appends a
/// hidden `__priority` column to the yield schema.
fn convert_to_fixpoint_plans(
    rules: &[LocyRulePlan],
    registry: &DerivedScanRegistry,
    deterministic_best_by: bool,
) -> DFResult<Vec<FixpointRulePlan>> {
    rules
        .iter()
        .map(|rule| {
            let yield_schema = yield_columns_to_arrow_schema(&rule.yield_schema);
            // Positions of the columns flagged `is_key` in the yield schema.
            let key_column_indices: Vec<usize> = rule
                .yield_schema
                .iter()
                .enumerate()
                .filter(|(_, yc)| yc.is_key)
                .map(|(i, _)| i)
                .collect();

            // Each clause keeps its logical body plus resolved IS-ref bindings;
            // the first failed IS-ref resolution aborts the whole conversion.
            let clauses: Vec<FixpointClausePlan> = rule
                .clauses
                .iter()
                .map(|clause| {
                    let is_ref_bindings = convert_is_refs(&clause.is_refs, registry)?;
                    Ok(FixpointClausePlan {
                        body_logical: clause.body.clone(),
                        is_ref_bindings,
                        priority: clause.priority,
                    })
                })
                .collect::<DFResult<Vec<_>>>()?;

            let fold_bindings = convert_fold_bindings(&rule.fold_bindings, &rule.yield_schema)?;
            let best_by_criteria =
                convert_best_by_criteria(&rule.best_by_criteria, &rule.yield_schema)?;

            let has_priority = rule.priority.is_some();

            // Prioritized rules carry an extra nullable Int64 `__priority`
            // column; otherwise the schema is used as-is (shadowing is
            // intentional here).
            let yield_schema = if has_priority {
                let mut fields: Vec<Arc<Field>> = yield_schema.fields().iter().cloned().collect();
                fields.push(Arc::new(Field::new("__priority", DataType::Int64, true)));
                ArrowSchema::new(fields)
            } else {
                yield_schema
            };

            Ok(FixpointRulePlan {
                name: rule.name.clone(),
                clauses,
                yield_schema: Arc::new(yield_schema),
                key_column_indices,
                priority: rule.priority,
                has_fold: !rule.fold_bindings.is_empty(),
                fold_bindings,
                has_best_by: !rule.best_by_criteria.is_empty(),
                best_by_criteria,
                has_priority,
                deterministic: deterministic_best_by,
            })
        })
        .collect()
}
689
/// Resolves each IS-ref to a registry entry (preferring the self-ref entry,
/// otherwise the first entry for the rule) and, for negated refs, builds the
/// anti-join column pairs used to subtract matching facts.
fn convert_is_refs(
    is_refs: &[LocyIsRef],
    registry: &DerivedScanRegistry,
) -> DFResult<Vec<IsRefBinding>> {
    is_refs
        .iter()
        .map(|is_ref| {
            let entries = registry.entries_for_rule(&is_ref.rule_name);
            // Self-ref entry wins; fall back to any entry; error if none exist.
            let entry = entries
                .iter()
                .find(|e| e.is_self_ref)
                .or_else(|| entries.first())
                .ok_or_else(|| {
                    datafusion::error::DataFusionError::Plan(format!(
                        "No derived scan entry found for IS-ref to '{}'",
                        is_ref.rule_name
                    ))
                })?;

            // For a negated ref, pair each *variable* subject with the
            // entry-schema field at the same position (falling back to the
            // variable's own name when the schema has fewer fields).
            // Non-variable subjects are skipped.
            // NOTE(review): assumes subject order matches the derived
            // relation's column order — confirm against the planner.
            let anti_join_cols = if is_ref.negated {
                is_ref
                    .subjects
                    .iter()
                    .enumerate()
                    .filter_map(|(i, s)| {
                        if let uni_cypher::ast::Expr::Variable(var) = s {
                            let right_col = entry
                                .schema
                                .fields()
                                .get(i)
                                .map(|f| f.name().clone())
                                .unwrap_or_else(|| var.clone());
                            Some((var.clone(), right_col))
                        } else {
                            None
                        }
                    })
                    .collect()
            } else {
                Vec::new()
            };

            Ok(IsRefBinding {
                derived_scan_index: entry.scan_index,
                rule_name: is_ref.rule_name.clone(),
                is_self_ref: entry.is_self_ref,
                negated: is_ref.negated,
                anti_join_cols,
            })
        })
        .collect()
}
750
/// Converts `(output_name, aggregate_expr)` FOLD pairs into positional
/// `FoldBinding`s over the yield schema.
///
/// NOTE(review): the aggregate's *input* column name parsed from the
/// expression is discarded; `input_col_index` is resolved by the OUTPUT
/// binding name instead — presumably the fold column carries the same name
/// in the yield schema. Confirm this is intended.
fn convert_fold_bindings(
    fold_bindings: &[(String, Expr)],
    yield_schema: &[LocyYieldColumn],
) -> DFResult<Vec<FoldBinding>> {
    fold_bindings
        .iter()
        .map(|(name, expr)| {
            let (kind, _input_col_name) = parse_fold_aggregate(expr)?;
            // Position of the output binding in the yield schema; errors if absent.
            let input_col_index = yield_schema
                .iter()
                .position(|yc| yc.name == *name)
                .ok_or_else(|| {
                    datafusion::error::DataFusionError::Plan(format!(
                        "FOLD column '{}' not found in yield schema",
                        name
                    ))
                })?;
            Ok(FoldBinding {
                output_name: name.clone(),
                kind,
                input_col_index,
            })
        })
        .collect()
}
783
784fn parse_fold_aggregate(expr: &Expr) -> DFResult<(FoldAggKind, String)> {
786 match expr {
787 Expr::FunctionCall { name, args, .. } => {
788 let kind = match name.to_uppercase().as_str() {
789 "SUM" | "MSUM" => FoldAggKind::Sum,
790 "MAX" | "MMAX" => FoldAggKind::Max,
791 "MIN" | "MMIN" => FoldAggKind::Min,
792 "COUNT" | "MCOUNT" => FoldAggKind::Count,
793 "AVG" => FoldAggKind::Avg,
794 "COLLECT" => FoldAggKind::Collect,
795 _ => {
796 return Err(datafusion::error::DataFusionError::Plan(format!(
797 "Unknown FOLD aggregate function: {}",
798 name
799 )));
800 }
801 };
802 let col_name = match args.first() {
803 Some(Expr::Variable(v)) => v.clone(),
804 Some(Expr::Property(_, prop)) => prop.clone(),
805 _ => {
806 return Err(datafusion::error::DataFusionError::Plan(
807 "FOLD aggregate argument must be a variable or property reference"
808 .to_string(),
809 ));
810 }
811 };
812 Ok((kind, col_name))
813 }
814 _ => Err(datafusion::error::DataFusionError::Plan(
815 "FOLD binding must be a function call (e.g., SUM(x))".to_string(),
816 )),
817 }
818}
819
820fn convert_best_by_criteria(
827 criteria: &[(Expr, bool)],
828 yield_schema: &[LocyYieldColumn],
829) -> DFResult<Vec<SortCriterion>> {
830 criteria
831 .iter()
832 .map(|(expr, ascending)| {
833 let col_name = match expr {
834 Expr::Property(_, prop) => prop.clone(),
835 Expr::Variable(v) => v.clone(),
836 _ => {
837 return Err(datafusion::error::DataFusionError::Plan(
838 "BEST BY criterion must be a variable or property reference".to_string(),
839 ));
840 }
841 };
842 let col_index = yield_schema
844 .iter()
845 .position(|yc| yc.name == col_name)
846 .or_else(|| {
847 let short_name = col_name.rsplit('.').next().unwrap_or(&col_name);
848 yield_schema.iter().position(|yc| yc.name == short_name)
849 })
850 .ok_or_else(|| {
851 datafusion::error::DataFusionError::Plan(format!(
852 "BEST BY column '{}' not found in yield schema",
853 col_name
854 ))
855 })?;
856 Ok(SortCriterion {
857 col_index,
858 ascending: *ascending,
859 nulls_first: false,
860 })
861 })
862 .collect()
863}
864
865fn yield_columns_to_arrow_schema(columns: &[LocyYieldColumn]) -> ArrowSchema {
871 let fields: Vec<Arc<Field>> = columns
872 .iter()
873 .map(|yc| Arc::new(Field::new(&yc.name, yc.data_type.clone(), true)))
874 .collect();
875 ArrowSchema::new(fields)
876}
877
878fn build_fixpoint_output_schema(rules: &[LocyRulePlan]) -> SchemaRef {
880 if let Some(rule) = rules.first() {
883 Arc::new(yield_columns_to_arrow_schema(&rule.yield_schema))
884 } else {
885 Arc::new(ArrowSchema::empty())
886 }
887}
888
889fn build_stats_batch(
891 derived_store: &DerivedStore,
892 _strata: &[LocyStratum],
893 output_schema: SchemaRef,
894) -> RecordBatch {
895 let mut rule_names: Vec<String> = derived_store.rule_names().map(String::from).collect();
897 rule_names.sort();
898
899 let name_col: arrow_array::StringArray = rule_names.iter().map(|s| Some(s.as_str())).collect();
900 let count_col: arrow_array::Int64Array = rule_names
901 .iter()
902 .map(|name| Some(derived_store.fact_count(name) as i64))
903 .collect();
904
905 let stats_schema = stats_schema();
906 RecordBatch::try_new(stats_schema, vec![Arc::new(name_col), Arc::new(count_col)])
907 .unwrap_or_else(|_| RecordBatch::new_empty(output_schema))
908}
909
910pub fn stats_schema() -> SchemaRef {
912 Arc::new(ArrowSchema::new(vec![
913 Arc::new(Field::new("rule_name", DataType::Utf8, false)),
914 Arc::new(Field::new("fact_count", DataType::Int64, false)),
915 ]))
916}
917
#[cfg(test)]
mod tests {
    use super::*;
    use arrow_array::{Int64Array, LargeBinaryArray, StringArray};

    // Inserted facts are retrievable by rule name; unknown names yield None.
    #[test]
    fn test_derived_store_insert_and_get() {
        let mut store = DerivedStore::new();
        assert!(store.get("test").is_none());

        let schema = Arc::new(ArrowSchema::new(vec![Arc::new(Field::new(
            "x",
            DataType::LargeBinary,
            true,
        ))]));
        let batch = RecordBatch::try_new(
            Arc::clone(&schema),
            vec![Arc::new(LargeBinaryArray::from(vec![
                Some(b"a" as &[u8]),
                Some(b"b"),
            ]))],
        )
        .unwrap();

        store.insert("test".to_string(), vec![batch.clone()]);

        let facts = store.get("test").unwrap();
        assert_eq!(facts.len(), 1);
        assert_eq!(facts[0].num_rows(), 2);
    }

    // fact_count sums rows across all batches; unknown rules count as 0.
    #[test]
    fn test_derived_store_fact_count() {
        let mut store = DerivedStore::new();
        assert_eq!(store.fact_count("empty"), 0);

        let schema = Arc::new(ArrowSchema::new(vec![Arc::new(Field::new(
            "x",
            DataType::LargeBinary,
            true,
        ))]));
        let batch1 = RecordBatch::try_new(
            Arc::clone(&schema),
            vec![Arc::new(LargeBinaryArray::from(vec![Some(b"a" as &[u8])]))],
        )
        .unwrap();
        let batch2 = RecordBatch::try_new(
            Arc::clone(&schema),
            vec![Arc::new(LargeBinaryArray::from(vec![
                Some(b"b" as &[u8]),
                Some(b"c"),
            ]))],
        )
        .unwrap();

        store.insert("test".to_string(), vec![batch1, batch2]);
        assert_eq!(store.fact_count("test"), 3);
    }

    // stats_schema exposes exactly (rule_name: Utf8, fact_count: Int64).
    #[test]
    fn test_stats_batch_schema() {
        let schema = stats_schema();
        assert_eq!(schema.fields().len(), 2);
        assert_eq!(schema.field(0).name(), "rule_name");
        assert_eq!(schema.field(1).name(), "fact_count");
        assert_eq!(schema.field(0).data_type(), &DataType::Utf8);
        assert_eq!(schema.field(1).data_type(), &DataType::Int64);
    }

    // The stats batch carries one row per rule with the correct fact count.
    #[test]
    fn test_stats_batch_content() {
        let mut store = DerivedStore::new();
        let schema = Arc::new(ArrowSchema::new(vec![Arc::new(Field::new(
            "x",
            DataType::LargeBinary,
            true,
        ))]));
        let batch = RecordBatch::try_new(
            Arc::clone(&schema),
            vec![Arc::new(LargeBinaryArray::from(vec![
                Some(b"a" as &[u8]),
                Some(b"b"),
            ]))],
        )
        .unwrap();
        store.insert("reach".to_string(), vec![batch]);

        let output_schema = stats_schema();
        let stats = build_stats_batch(&store, &[], Arc::clone(&output_schema));
        assert_eq!(stats.num_rows(), 1);

        let names = stats
            .column(0)
            .as_any()
            .downcast_ref::<StringArray>()
            .unwrap();
        assert_eq!(names.value(0), "reach");

        let counts = stats
            .column(1)
            .as_any()
            .downcast_ref::<Int64Array>()
            .unwrap();
        assert_eq!(counts.value(0), 2);
    }

    // Conversion preserves column order, names, and types; all fields nullable.
    #[test]
    fn test_yield_columns_to_arrow_schema() {
        let columns = vec![
            LocyYieldColumn {
                name: "a".to_string(),
                is_key: true,
                data_type: DataType::UInt64,
            },
            LocyYieldColumn {
                name: "b".to_string(),
                is_key: false,
                data_type: DataType::LargeUtf8,
            },
            LocyYieldColumn {
                name: "c".to_string(),
                is_key: true,
                data_type: DataType::Float64,
            },
        ];

        let schema = yield_columns_to_arrow_schema(&columns);
        assert_eq!(schema.fields().len(), 3);
        assert_eq!(schema.field(0).name(), "a");
        assert_eq!(schema.field(1).name(), "b");
        assert_eq!(schema.field(2).name(), "c");
        assert_eq!(schema.field(0).data_type(), &DataType::UInt64);
        assert_eq!(schema.field(1).data_type(), &DataType::LargeUtf8);
        assert_eq!(schema.field(2).data_type(), &DataType::Float64);
        for field in schema.fields() {
            assert!(field.is_nullable());
        }
    }

    // The is_key filter (as used in convert_to_fixpoint_plans) yields the
    // positions of key columns only.
    #[test]
    fn test_key_column_indices() {
        let columns = [
            LocyYieldColumn {
                name: "a".to_string(),
                is_key: true,
                data_type: DataType::LargeBinary,
            },
            LocyYieldColumn {
                name: "b".to_string(),
                is_key: false,
                data_type: DataType::LargeBinary,
            },
            LocyYieldColumn {
                name: "c".to_string(),
                is_key: true,
                data_type: DataType::LargeBinary,
            },
        ];

        let key_indices: Vec<usize> = columns
            .iter()
            .enumerate()
            .filter(|(_, yc)| yc.is_key)
            .map(|(i, _)| i)
            .collect();
        assert_eq!(key_indices, vec![0, 2]);
    }

    // SUM(var) parses to the Sum kind with the variable name as input column.
    #[test]
    fn test_parse_fold_aggregate_sum() {
        let expr = Expr::FunctionCall {
            name: "SUM".to_string(),
            args: vec![Expr::Variable("cost".to_string())],
            distinct: false,
            window_spec: None,
        };
        let (kind, col) = parse_fold_aggregate(&expr).unwrap();
        assert!(matches!(kind, FoldAggKind::Sum));
        assert_eq!(col, "cost");
    }

    // M-prefixed (monotonic) variants map to the same kinds, e.g. MMAX -> Max.
    #[test]
    fn test_parse_fold_aggregate_monotonic() {
        let expr = Expr::FunctionCall {
            name: "MMAX".to_string(),
            args: vec![Expr::Variable("score".to_string())],
            distinct: false,
            window_spec: None,
        };
        let (kind, col) = parse_fold_aggregate(&expr).unwrap();
        assert!(matches!(kind, FoldAggKind::Max));
        assert_eq!(col, "score");
    }

    // Unrecognized aggregate names are rejected with a plan error.
    #[test]
    fn test_parse_fold_aggregate_unknown() {
        let expr = Expr::FunctionCall {
            name: "UNKNOWN_AGG".to_string(),
            args: vec![Expr::Variable("x".to_string())],
            distinct: false,
            window_spec: None,
        };
        assert!(parse_fold_aggregate(&expr).is_err());
    }

    // An empty store produces an empty (zero-row) stats batch.
    #[test]
    fn test_no_commands_returns_stats() {
        let store = DerivedStore::new();
        let output_schema = stats_schema();
        let stats = build_stats_batch(&store, &[], Arc::clone(&output_schema));
        assert_eq!(stats.num_rows(), 0);
    }
}
1136}