1use crate::query::df_graph::GraphExecutionContext;
17use crate::query::df_graph::common::{
18 compute_plan_properties, evaluate_simple_expr, labels_data_type,
19};
20use crate::query::df_graph::scan::resolve_property_type;
21use arrow_array::builder::{
22 BooleanBuilder, Float32Builder, Float64Builder, Int64Builder, StringBuilder, UInt64Builder,
23};
24use arrow_array::{ArrayRef, RecordBatch};
25use arrow_schema::{DataType, Field, Schema, SchemaRef};
26use datafusion::common::Result as DFResult;
27use datafusion::execution::{RecordBatchStream, SendableRecordBatchStream, TaskContext};
28use datafusion::physical_plan::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet};
29use datafusion::physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties};
30use futures::Stream;
31use std::any::Any;
32use std::collections::HashMap;
33use std::fmt;
34use std::pin::Pin;
35use std::sync::Arc;
36use std::task::{Context, Poll};
37use uni_common::Value;
38use uni_common::core::id::Vid;
39use uni_cypher::ast::Expr;
40
/// Normalize a YIELD column name (case-insensitively) to its canonical form.
///
/// `vid`/`_vid` → `vid`; `distance`/`dist`/`_distance` → `distance`;
/// `score`/`_score` → `score`; the specific score variants map to
/// themselves; anything unrecognized is treated as a `node` yield.
pub(crate) fn map_yield_to_canonical(yield_name: &str) -> String {
    let lowered = yield_name.to_lowercase();
    match lowered.as_str() {
        "vid" | "_vid" => "vid".to_string(),
        "distance" | "dist" | "_distance" => "distance".to_string(),
        "score" | "_score" => "score".to_string(),
        // These three are already canonical once lowercased.
        "vector_score" | "fts_score" | "raw_score" => lowered,
        _ => "node".to_string(),
    }
}
59
/// Physical plan node executing a Cypher `CALL procedure(...) YIELD ...`
/// against the graph. It is a leaf plan that produces at most one record
/// batch per execution.
pub struct GraphProcedureCallExec {
    // Shared execution context: storage, registries, timeout tracking.
    graph_ctx: Arc<GraphExecutionContext>,

    // Fully-qualified procedure name, e.g. "uni.schema.labels".
    procedure_name: String,

    // Unevaluated argument expressions; evaluated against `params` in execute().
    arguments: Vec<Expr>,

    // (yield_name, optional alias) pairs defining the output columns, in order.
    yield_items: Vec<(String, Option<String>)>,

    // Bound query parameters used when evaluating `arguments`.
    params: HashMap<String, Value>,

    // Per node-variable property names requested by the query (search procedures).
    target_properties: HashMap<String, Vec<String>>,

    // Output schema derived eagerly in `new` via `build_schema`.
    schema: SchemaRef,

    // Cached DataFusion plan properties computed from `schema`.
    properties: PlanProperties,

    // Execution metrics; populated by the result stream.
    metrics: ExecutionPlanMetricsSet,
}
92
impl fmt::Debug for GraphProcedureCallExec {
    // Manual impl: only the identifying fields are shown; the execution
    // context and plan internals are deliberately omitted.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("GraphProcedureCallExec")
            .field("procedure_name", &self.procedure_name)
            .field("yield_items", &self.yield_items)
            .finish()
    }
}
101
impl GraphProcedureCallExec {
    /// Create a procedure-call plan node, eagerly deriving the output
    /// schema and DataFusion plan properties from the procedure name and
    /// YIELD items.
    pub fn new(
        graph_ctx: Arc<GraphExecutionContext>,
        procedure_name: String,
        arguments: Vec<Expr>,
        yield_items: Vec<(String, Option<String>)>,
        params: HashMap<String, Value>,
        target_properties: HashMap<String, Vec<String>>,
    ) -> Self {
        let schema = Self::build_schema(
            &procedure_name,
            &yield_items,
            &target_properties,
            &graph_ctx,
        );
        let properties = compute_plan_properties(schema.clone());

        Self {
            graph_ctx,
            procedure_name,
            arguments,
            yield_items,
            params,
            target_properties,
            schema,
            properties,
            metrics: ExecutionPlanMetricsSet::new(),
        }
    }

    /// Derive the Arrow output schema for the procedure call.
    ///
    /// Columns follow `yield_items` order and are named by alias when one is
    /// given, otherwise by the yield name. Type selection per procedure family:
    /// - `uni.schema.*`: fixed per-column types keyed on the yield name
    ///   (unknown names default to Utf8);
    /// - `uni.vector.query` / `uni.fts.query` / `uni.search`: a `node` yield
    ///   expands to `_vid` (UInt64), a value column, `_labels`, and one column
    ///   per requested property;
    /// - `uni.algo.*`: yield types come from the algorithm signature;
    /// - anything else: types from the procedure registry, Utf8 otherwise.
    fn build_schema(
        procedure_name: &str,
        yield_items: &[(String, Option<String>)],
        target_properties: &HashMap<String, Vec<String>>,
        graph_ctx: &GraphExecutionContext,
    ) -> SchemaRef {
        let mut fields = Vec::new();

        match procedure_name {
            "uni.schema.labels" => {
                for (name, alias) in yield_items {
                    let col_name = alias.as_ref().unwrap_or(name);
                    let data_type = match name.as_str() {
                        "label" => DataType::Utf8,
                        "propertyCount" | "nodeCount" | "indexCount" => DataType::Int64,
                        _ => DataType::Utf8,
                    };
                    fields.push(Field::new(col_name, data_type, true));
                }
            }
            "uni.schema.edgeTypes" | "uni.schema.relationshipTypes" => {
                for (name, alias) in yield_items {
                    let col_name = alias.as_ref().unwrap_or(name);
                    let data_type = match name.as_str() {
                        "type" | "relationshipType" => DataType::Utf8,
                        "propertyCount" => DataType::Int64,
                        "sourceLabels" | "targetLabels" => DataType::Utf8, _ => DataType::Utf8,
                    };
                    fields.push(Field::new(col_name, data_type, true));
                }
            }
            "uni.schema.indexes" => {
                for (name, alias) in yield_items {
                    let col_name = alias.as_ref().unwrap_or(name);
                    let data_type = match name.as_str() {
                        "name" | "type" | "label" | "state" | "properties" => DataType::Utf8,
                        _ => DataType::Utf8,
                    };
                    fields.push(Field::new(col_name, data_type, true));
                }
            }
            "uni.schema.constraints" => {
                for (name, alias) in yield_items {
                    let col_name = alias.as_ref().unwrap_or(name);
                    let data_type = match name.as_str() {
                        "enabled" => DataType::Boolean,
                        _ => DataType::Utf8,
                    };
                    fields.push(Field::new(col_name, data_type, true));
                }
            }
            "uni.schema.labelInfo" => {
                for (name, alias) in yield_items {
                    let col_name = alias.as_ref().unwrap_or(name);
                    let data_type = match name.as_str() {
                        "property" | "dataType" => DataType::Utf8,
                        "nullable" | "indexed" | "unique" => DataType::Boolean,
                        _ => DataType::Utf8,
                    };
                    fields.push(Field::new(col_name, data_type, true));
                }
            }
            "uni.vector.query" | "uni.fts.query" | "uni.search" => {
                for (name, alias) in yield_items {
                    let output_name = alias.as_ref().unwrap_or(name);
                    let canonical = map_yield_to_canonical(name);

                    match canonical.as_str() {
                        "node" => {
                            // Node yields expand to three fixed columns:
                            // "<name>._vid", the value column, "<name>._labels"...
                            fields.push(Field::new(
                                format!("{}._vid", output_name),
                                DataType::UInt64,
                                false,
                            ));
                            fields.push(Field::new(output_name, DataType::Utf8, false));
                            fields.push(Field::new(
                                format!("{}._labels", output_name),
                                labels_data_type(),
                                true,
                            ));

                            // ...plus one column per requested property, typed
                            // from the first label (in map iteration order) whose
                            // schema declares it; unresolved properties fall back
                            // to the label-agnostic default type.
                            if let Some(props) = target_properties.get(output_name.as_str()) {
                                let uni_schema = graph_ctx.storage().schema_manager().schema();
                                for prop_name in props {
                                    let col_name = format!("{}.{}", output_name, prop_name);
                                    let arrow_type = resolve_property_type(prop_name, None);
                                    let resolved_type = uni_schema
                                        .properties
                                        .values()
                                        .find_map(|label_props| {
                                            label_props.get(prop_name.as_str()).map(|_| {
                                                resolve_property_type(prop_name, Some(label_props))
                                            })
                                        })
                                        .unwrap_or(arrow_type);
                                    fields.push(Field::new(&col_name, resolved_type, true));
                                }
                            }
                        }
                        "distance" => {
                            fields.push(Field::new(output_name, DataType::Float64, true));
                        }
                        "score" | "vector_score" | "fts_score" | "raw_score" => {
                            fields.push(Field::new(output_name, DataType::Float32, true));
                        }
                        "vid" => {
                            fields.push(Field::new(output_name, DataType::Int64, true));
                        }
                        _ => {
                            fields.push(Field::new(output_name, DataType::Utf8, true));
                        }
                    }
                }
            }
            name if name.starts_with("uni.algo.") => {
                if let Some(registry) = graph_ctx.algo_registry()
                    && let Some(procedure) = registry.get(name)
                {
                    let sig = procedure.signature();
                    for (yield_name, alias) in yield_items {
                        let col_name = alias.as_ref().unwrap_or(yield_name);
                        let yield_vt = sig.yields.iter().find(|(n, _)| *n == yield_name.as_str());
                        let data_type = yield_vt
                            .map(|(_, vt)| value_type_to_arrow(vt))
                            .unwrap_or(DataType::Utf8);
                        let mut field = Field::new(col_name, data_type, true);
                        // Complex values (lists/maps/graph entities) are
                        // string-encoded; the metadata flag lets downstream
                        // consumers know to decode the column.
                        if yield_vt.is_some_and(|(_, vt)| is_complex_value_type(vt)) {
                            let mut metadata = std::collections::HashMap::new();
                            metadata.insert("cv_encoded".to_string(), "true".to_string());
                            field = field.with_metadata(metadata);
                        }
                        fields.push(field);
                    }
                } else {
                    // Registry missing or algorithm unknown: all-Utf8 columns.
                    for (name, alias) in yield_items {
                        let col_name = alias.as_ref().unwrap_or(name);
                        fields.push(Field::new(col_name, DataType::Utf8, true));
                    }
                }
            }
            _ => {
                if let Some(registry) = graph_ctx.procedure_registry()
                    && let Some(proc_def) = registry.get(procedure_name)
                {
                    for (name, alias) in yield_items {
                        let col_name = alias.as_ref().unwrap_or(name);
                        let data_type = proc_def
                            .outputs
                            .iter()
                            .find(|o| o.name == *name)
                            .map(|o| procedure_value_type_to_arrow(&o.output_type))
                            .unwrap_or(DataType::Utf8);
                        fields.push(Field::new(col_name, data_type, true));
                    }
                } else if yield_items.is_empty() {
                    // CALL without YIELD: schema stays empty (zero columns).
                } else {
                    // Unknown procedure with YIELD: default every column to Utf8.
                    for (name, alias) in yield_items {
                        let col_name = alias.as_ref().unwrap_or(name);
                        fields.push(Field::new(col_name, DataType::Utf8, true));
                    }
                }
            }
        }

        Arc::new(Schema::new(fields))
    }
}
316
317fn value_type_to_arrow(vt: &uni_algo::algo::procedures::ValueType) -> DataType {
319 use uni_algo::algo::procedures::ValueType;
320 match vt {
321 ValueType::Int => DataType::Int64,
322 ValueType::Float => DataType::Float64,
323 ValueType::String => DataType::Utf8,
324 ValueType::Bool => DataType::Boolean,
325 ValueType::List
326 | ValueType::Map
327 | ValueType::Node
328 | ValueType::Relationship
329 | ValueType::Path
330 | ValueType::Any => DataType::Utf8,
331 }
332}
333
334fn is_complex_value_type(vt: &uni_algo::algo::procedures::ValueType) -> bool {
337 use uni_algo::algo::procedures::ValueType;
338 matches!(
339 vt,
340 ValueType::List
341 | ValueType::Map
342 | ValueType::Node
343 | ValueType::Relationship
344 | ValueType::Path
345 )
346}
347
348fn procedure_value_type_to_arrow(
350 vt: &crate::query::executor::procedure::ProcedureValueType,
351) -> DataType {
352 use crate::query::executor::procedure::ProcedureValueType;
353 match vt {
354 ProcedureValueType::Integer => DataType::Int64,
355 ProcedureValueType::Float | ProcedureValueType::Number => DataType::Float64,
356 ProcedureValueType::Boolean => DataType::Boolean,
357 ProcedureValueType::String | ProcedureValueType::Any => DataType::Utf8,
358 }
359}
360
361impl DisplayAs for GraphProcedureCallExec {
362 fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter<'_>) -> fmt::Result {
363 write!(
364 f,
365 "GraphProcedureCallExec: procedure={}",
366 self.procedure_name
367 )
368 }
369}
370
impl ExecutionPlan for GraphProcedureCallExec {
    fn name(&self) -> &str {
        "GraphProcedureCallExec"
    }

    fn as_any(&self) -> &dyn Any {
        self
    }

    fn schema(&self) -> SchemaRef {
        self.schema.clone()
    }

    fn properties(&self) -> &PlanProperties {
        &self.properties
    }

    // Leaf node: procedure calls read directly from the graph context and
    // have no input plans.
    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
        vec![]
    }

    fn with_new_children(
        self: Arc<Self>,
        children: Vec<Arc<dyn ExecutionPlan>>,
    ) -> DFResult<Arc<dyn ExecutionPlan>> {
        // A leaf plan cannot accept children; reject instead of silently
        // dropping them.
        if !children.is_empty() {
            return Err(datafusion::error::DataFusionError::Internal(
                "GraphProcedureCallExec has no children".to_string(),
            ));
        }
        Ok(self)
    }

    /// Start procedure execution for `partition`. Argument expressions are
    /// evaluated here, before the stream exists, so evaluation errors fail
    /// fast and the stream only handles concrete values.
    fn execute(
        &self,
        partition: usize,
        _context: Arc<TaskContext>,
    ) -> DFResult<SendableRecordBatchStream> {
        let metrics = BaselineMetrics::new(&self.metrics, partition);

        // Evaluate each argument against the bound parameters; `?` aborts
        // execute() on the first failure.
        let mut evaluated_args = Vec::with_capacity(self.arguments.len());
        for arg in &self.arguments {
            evaluated_args.push(evaluate_simple_expr(arg, &self.params)?);
        }

        Ok(Box::pin(ProcedureCallStream::new(
            self.graph_ctx.clone(),
            self.procedure_name.clone(),
            evaluated_args,
            self.yield_items.clone(),
            self.target_properties.clone(),
            self.schema.clone(),
            metrics,
        )))
    }

    fn metrics(&self) -> Option<MetricsSet> {
        Some(self.metrics.clone_inner())
    }
}
432
/// State machine for `ProcedureCallStream`: the procedure runs exactly once
/// and yields at most one batch.
enum ProcedureCallState {
    // Not started; the first poll constructs the execution future.
    Init,
    // Procedure future in flight; polled until ready.
    Executing(Pin<Box<dyn std::future::Future<Output = DFResult<Option<RecordBatch>>> + Send>>),
    // Result (or error) already delivered; subsequent polls end the stream.
    Done,
}
446
/// One-shot stream backing `GraphProcedureCallExec::execute`: drives
/// `execute_procedure` to completion, yields its single result batch (if
/// any), then terminates.
struct ProcedureCallStream {
    graph_ctx: Arc<GraphExecutionContext>,
    procedure_name: String,
    // Arguments already evaluated to concrete values in execute().
    evaluated_args: Vec<Value>,
    yield_items: Vec<(String, Option<String>)>,
    target_properties: HashMap<String, Vec<String>>,
    // Output schema; also returned by RecordBatchStream::schema.
    schema: SchemaRef,
    // Init -> Executing -> Done; see ProcedureCallState.
    state: ProcedureCallState,
    metrics: BaselineMetrics,
}
458
459impl ProcedureCallStream {
460 fn new(
461 graph_ctx: Arc<GraphExecutionContext>,
462 procedure_name: String,
463 evaluated_args: Vec<Value>,
464 yield_items: Vec<(String, Option<String>)>,
465 target_properties: HashMap<String, Vec<String>>,
466 schema: SchemaRef,
467 metrics: BaselineMetrics,
468 ) -> Self {
469 Self {
470 graph_ctx,
471 procedure_name,
472 evaluated_args,
473 yield_items,
474 target_properties,
475 schema,
476 state: ProcedureCallState::Init,
477 metrics,
478 }
479 }
480}
481
impl Stream for ProcedureCallStream {
    type Item = DFResult<RecordBatch>;

    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        loop {
            // Move the state out so the boxed future can be polled by value;
            // `Done` serves as a placeholder and is overwritten on every path
            // that keeps the stream alive.
            let state = std::mem::replace(&mut self.state, ProcedureCallState::Done);

            match state {
                ProcedureCallState::Init => {
                    // Clone everything the future needs so it is self-contained
                    // and detached from `self`.
                    let graph_ctx = self.graph_ctx.clone();
                    let procedure_name = self.procedure_name.clone();
                    let evaluated_args = self.evaluated_args.clone();
                    let yield_items = self.yield_items.clone();
                    let target_properties = self.target_properties.clone();
                    let schema = self.schema.clone();

                    let fut = async move {
                        // Fail fast if the query deadline has already passed.
                        graph_ctx.check_timeout().map_err(|e| {
                            datafusion::error::DataFusionError::Execution(e.to_string())
                        })?;

                        execute_procedure(
                            &graph_ctx,
                            &procedure_name,
                            &evaluated_args,
                            &yield_items,
                            &target_properties,
                            &schema,
                        )
                        .await
                    };

                    // Loop back around to poll the freshly created future.
                    self.state = ProcedureCallState::Executing(Box::pin(fut));
                }
                ProcedureCallState::Executing(mut fut) => match fut.as_mut().poll(cx) {
                    Poll::Ready(Ok(batch)) => {
                        self.state = ProcedureCallState::Done;
                        // Record output rows (0 when the procedure produced
                        // no batch) before handing the batch downstream.
                        self.metrics
                            .record_output(batch.as_ref().map(|b| b.num_rows()).unwrap_or(0));
                        return Poll::Ready(batch.map(Ok));
                    }
                    Poll::Ready(Err(e)) => {
                        self.state = ProcedureCallState::Done;
                        return Poll::Ready(Some(Err(e)));
                    }
                    Poll::Pending => {
                        // Put the future back so the next poll resumes it.
                        self.state = ProcedureCallState::Executing(fut);
                        return Poll::Pending;
                    }
                },
                ProcedureCallState::Done => {
                    return Poll::Ready(None);
                }
            }
        }
    }
}
539
540impl RecordBatchStream for ProcedureCallStream {
541 fn schema(&self) -> SchemaRef {
542 self.schema.clone()
543 }
544}
545
/// Top-level dispatcher: route a procedure call to its implementation.
///
/// Built-in `uni.schema.*` and search procedures are handled inline;
/// `uni.algo.*` goes through the algorithm registry, and anything else
/// falls back to the user procedure registry.
async fn execute_procedure(
    graph_ctx: &GraphExecutionContext,
    procedure_name: &str,
    args: &[Value],
    yield_items: &[(String, Option<String>)],
    target_properties: &HashMap<String, Vec<String>>,
    schema: &SchemaRef,
) -> DFResult<Option<RecordBatch>> {
    match procedure_name {
        "uni.schema.labels" => execute_schema_labels(graph_ctx, yield_items, schema).await,
        "uni.schema.edgeTypes" | "uni.schema.relationshipTypes" => {
            execute_schema_edge_types(graph_ctx, yield_items, schema).await
        }
        "uni.schema.indexes" => execute_schema_indexes(graph_ctx, yield_items, schema).await,
        "uni.schema.constraints" => {
            execute_schema_constraints(graph_ctx, yield_items, schema).await
        }
        "uni.schema.labelInfo" => {
            execute_schema_label_info(graph_ctx, args, yield_items, schema).await
        }
        "uni.vector.query" => {
            execute_vector_query(graph_ctx, args, yield_items, target_properties, schema).await
        }
        "uni.fts.query" => {
            execute_fts_query(graph_ctx, args, yield_items, target_properties, schema).await
        }
        "uni.search" => {
            execute_hybrid_search(graph_ctx, args, yield_items, target_properties, schema).await
        }
        name if name.starts_with("uni.algo.") => {
            execute_algo_procedure(graph_ctx, name, args, yield_items, schema).await
        }
        _ => {
            execute_registered_procedure(graph_ctx, procedure_name, args, yield_items, schema).await
        }
    }
}
588
589async fn execute_schema_labels(
594 graph_ctx: &GraphExecutionContext,
595 yield_items: &[(String, Option<String>)],
596 schema: &SchemaRef,
597) -> DFResult<Option<RecordBatch>> {
598 let uni_schema = graph_ctx.storage().schema_manager().schema();
599 let storage = graph_ctx.storage();
600
601 let mut rows: Vec<HashMap<String, Value>> = Vec::new();
603 for label_name in uni_schema.labels.keys() {
604 let mut row = HashMap::new();
605 row.insert("label".to_string(), Value::String(label_name.clone()));
606
607 let prop_count = uni_schema
608 .properties
609 .get(label_name)
610 .map(|p| p.len())
611 .unwrap_or(0);
612 row.insert("propertyCount".to_string(), Value::Int(prop_count as i64));
613
614 let node_count = if let Ok(ds) = storage.vertex_dataset(label_name) {
615 if let Ok(raw) = ds.open_raw().await {
616 raw.count_rows(None).await.unwrap_or(0)
617 } else {
618 0
619 }
620 } else {
621 0
622 };
623 row.insert("nodeCount".to_string(), Value::Int(node_count as i64));
624
625 let idx_count = uni_schema
626 .indexes
627 .iter()
628 .filter(|i| i.label() == label_name)
629 .count();
630 row.insert("indexCount".to_string(), Value::Int(idx_count as i64));
631
632 rows.push(row);
633 }
634
635 build_scalar_batch(&rows, yield_items, schema)
636}
637
638async fn execute_schema_edge_types(
639 graph_ctx: &GraphExecutionContext,
640 yield_items: &[(String, Option<String>)],
641 schema: &SchemaRef,
642) -> DFResult<Option<RecordBatch>> {
643 let uni_schema = graph_ctx.storage().schema_manager().schema();
644
645 let mut rows: Vec<HashMap<String, Value>> = Vec::new();
646 for (type_name, meta) in &uni_schema.edge_types {
647 let mut row = HashMap::new();
648 row.insert("type".to_string(), Value::String(type_name.clone()));
649 row.insert(
650 "relationshipType".to_string(),
651 Value::String(type_name.clone()),
652 );
653 row.insert(
654 "sourceLabels".to_string(),
655 Value::String(format!("{:?}", meta.src_labels)),
656 );
657 row.insert(
658 "targetLabels".to_string(),
659 Value::String(format!("{:?}", meta.dst_labels)),
660 );
661
662 let prop_count = uni_schema
663 .properties
664 .get(type_name)
665 .map(|p| p.len())
666 .unwrap_or(0);
667 row.insert("propertyCount".to_string(), Value::Int(prop_count as i64));
668
669 rows.push(row);
670 }
671
672 build_scalar_batch(&rows, yield_items, schema)
673}
674
675async fn execute_schema_indexes(
676 graph_ctx: &GraphExecutionContext,
677 yield_items: &[(String, Option<String>)],
678 schema: &SchemaRef,
679) -> DFResult<Option<RecordBatch>> {
680 let uni_schema = graph_ctx.storage().schema_manager().schema();
681
682 let mut rows: Vec<HashMap<String, Value>> = Vec::new();
683 for idx in &uni_schema.indexes {
684 use uni_common::core::schema::IndexDefinition;
685
686 let (type_name, properties_json) = match &idx {
688 IndexDefinition::Vector(v) => (
689 "VECTOR",
690 serde_json::to_string(&[&v.property]).unwrap_or_default(),
691 ),
692 IndexDefinition::FullText(f) => (
693 "FULLTEXT",
694 serde_json::to_string(&f.properties).unwrap_or_default(),
695 ),
696 IndexDefinition::Scalar(s) => (
697 "SCALAR",
698 serde_json::to_string(&s.properties).unwrap_or_default(),
699 ),
700 IndexDefinition::JsonFullText(j) => (
701 "JSON_FTS",
702 serde_json::to_string(&[&j.column]).unwrap_or_default(),
703 ),
704 IndexDefinition::Inverted(inv) => (
705 "INVERTED",
706 serde_json::to_string(&[&inv.property]).unwrap_or_default(),
707 ),
708 _ => ("UNKNOWN", String::new()),
709 };
710
711 let row = HashMap::from([
712 ("state".to_string(), Value::String("ONLINE".to_string())),
713 ("name".to_string(), Value::String(idx.name().to_string())),
714 ("type".to_string(), Value::String(type_name.to_string())),
715 ("label".to_string(), Value::String(idx.label().to_string())),
716 ("properties".to_string(), Value::String(properties_json)),
717 ]);
718 rows.push(row);
719 }
720
721 build_scalar_batch(&rows, yield_items, schema)
722}
723
724async fn execute_schema_constraints(
725 graph_ctx: &GraphExecutionContext,
726 yield_items: &[(String, Option<String>)],
727 schema: &SchemaRef,
728) -> DFResult<Option<RecordBatch>> {
729 let uni_schema = graph_ctx.storage().schema_manager().schema();
730
731 let mut rows: Vec<HashMap<String, Value>> = Vec::new();
732 for c in &uni_schema.constraints {
733 let mut row = HashMap::new();
734 row.insert("name".to_string(), Value::String(c.name.clone()));
735 row.insert("enabled".to_string(), Value::Bool(c.enabled));
736
737 match &c.constraint_type {
738 uni_common::core::schema::ConstraintType::Unique { properties } => {
739 row.insert("type".to_string(), Value::String("UNIQUE".to_string()));
740 row.insert(
741 "properties".to_string(),
742 Value::String(serde_json::to_string(&properties).unwrap_or_default()),
743 );
744 }
745 uni_common::core::schema::ConstraintType::Exists { property } => {
746 row.insert("type".to_string(), Value::String("EXISTS".to_string()));
747 row.insert(
748 "properties".to_string(),
749 Value::String(serde_json::to_string(&[&property]).unwrap_or_default()),
750 );
751 }
752 uni_common::core::schema::ConstraintType::Check { expression } => {
753 row.insert("type".to_string(), Value::String("CHECK".to_string()));
754 row.insert("expression".to_string(), Value::String(expression.clone()));
755 }
756 _ => {
757 row.insert("type".to_string(), Value::String("UNKNOWN".to_string()));
758 }
759 }
760
761 match &c.target {
762 uni_common::core::schema::ConstraintTarget::Label(l) => {
763 row.insert("label".to_string(), Value::String(l.clone()));
764 }
765 uni_common::core::schema::ConstraintTarget::EdgeType(t) => {
766 row.insert("relationshipType".to_string(), Value::String(t.clone()));
767 }
768 _ => {
769 row.insert("target".to_string(), Value::String("UNKNOWN".to_string()));
770 }
771 }
772
773 rows.push(row);
774 }
775
776 build_scalar_batch(&rows, yield_items, schema)
777}
778
/// Implements `uni.schema.labelInfo(label)`: one row per property declared
/// on the given label, with its type and nullable/indexed/unique flags.
/// An unknown label yields zero rows rather than an error.
async fn execute_schema_label_info(
    graph_ctx: &GraphExecutionContext,
    args: &[Value],
    yield_items: &[(String, Option<String>)],
    schema: &SchemaRef,
) -> DFResult<Option<RecordBatch>> {
    // The single required argument names the label to describe.
    let label_name = require_string_arg(args, 0, "uni.schema.labelInfo: first argument (label)")?;

    let uni_schema = graph_ctx.storage().schema_manager().schema();

    let mut rows: Vec<HashMap<String, Value>> = Vec::new();
    if let Some(props) = uni_schema.properties.get(&label_name) {
        for (prop_name, prop_meta) in props {
            let mut row = HashMap::new();
            row.insert("property".to_string(), Value::String(prop_name.clone()));
            row.insert(
                "dataType".to_string(),
                Value::String(format!("{:?}", prop_meta.r#type)),
            );
            row.insert("nullable".to_string(), Value::Bool(prop_meta.nullable));

            // A property counts as indexed if any index on this label covers
            // it; JSON full-text indexes cover the whole label regardless of
            // the property name.
            let is_indexed = uni_schema.indexes.iter().any(|idx| match idx {
                uni_common::core::schema::IndexDefinition::Vector(v) => {
                    v.label == label_name && v.property == *prop_name
                }
                uni_common::core::schema::IndexDefinition::Scalar(s) => {
                    s.label == label_name && s.properties.contains(prop_name)
                }
                uni_common::core::schema::IndexDefinition::FullText(f) => {
                    f.label == label_name && f.properties.contains(prop_name)
                }
                uni_common::core::schema::IndexDefinition::Inverted(inv) => {
                    inv.label == label_name && inv.property == *prop_name
                }
                uni_common::core::schema::IndexDefinition::JsonFullText(j) => j.label == label_name,
                _ => false,
            });
            row.insert("indexed".to_string(), Value::Bool(is_indexed));

            // Unique iff an *enabled* UNIQUE constraint targeting this label
            // lists the property.
            let unique = uni_schema.constraints.iter().any(|c| {
                if let uni_common::core::schema::ConstraintTarget::Label(l) = &c.target
                    && l == &label_name
                    && c.enabled
                    && let uni_common::core::schema::ConstraintType::Unique { properties } =
                        &c.constraint_type
                {
                    return properties.contains(prop_name);
                }
                false
            });
            row.insert("unique".to_string(), Value::Bool(unique));

            rows.push(row);
        }
    }

    build_scalar_batch(&rows, yield_items, schema)
}
837
838fn build_typed_column<'a>(
843 values: impl Iterator<Item = Option<&'a Value>>,
844 num_rows: usize,
845 data_type: &DataType,
846) -> ArrayRef {
847 match data_type {
848 DataType::Int64 => {
849 let mut builder = Int64Builder::with_capacity(num_rows);
850 for val in values {
851 match val.and_then(|v| v.as_i64()) {
852 Some(i) => builder.append_value(i),
853 None => builder.append_null(),
854 }
855 }
856 Arc::new(builder.finish())
857 }
858 DataType::Float64 => {
859 let mut builder = Float64Builder::with_capacity(num_rows);
860 for val in values {
861 match val.and_then(|v| v.as_f64()) {
862 Some(f) => builder.append_value(f),
863 None => builder.append_null(),
864 }
865 }
866 Arc::new(builder.finish())
867 }
868 DataType::Boolean => {
869 let mut builder = BooleanBuilder::with_capacity(num_rows);
870 for val in values {
871 match val.and_then(|v| v.as_bool()) {
872 Some(b) => builder.append_value(b),
873 None => builder.append_null(),
874 }
875 }
876 Arc::new(builder.finish())
877 }
878 _ => {
879 let mut builder = StringBuilder::with_capacity(num_rows, num_rows * 32);
881 for val in values {
882 match val {
883 Some(Value::String(s)) => builder.append_value(s),
884 Some(v) => builder.append_value(format!("{v}")),
885 None => builder.append_null(),
886 }
887 }
888 Arc::new(builder.finish())
889 }
890 }
891}
892
893fn create_empty_batch(schema: SchemaRef) -> DFResult<RecordBatch> {
899 if schema.fields().is_empty() {
900 let options = arrow_array::RecordBatchOptions::new().with_row_count(Some(0));
901 RecordBatch::try_new_with_options(schema, vec![], &options)
902 .map_err(|e| datafusion::error::DataFusionError::ArrowError(Box::new(e), None))
903 } else {
904 Ok(RecordBatch::new_empty(schema))
905 }
906}
907
908fn build_scalar_batch(
910 rows: &[HashMap<String, Value>],
911 yield_items: &[(String, Option<String>)],
912 schema: &SchemaRef,
913) -> DFResult<Option<RecordBatch>> {
914 if rows.is_empty() {
915 return Ok(Some(create_empty_batch(schema.clone())?));
916 }
917
918 let num_rows = rows.len();
919 let mut columns: Vec<ArrayRef> = Vec::new();
920
921 for (idx, (name, _alias)) in yield_items.iter().enumerate() {
922 let field = schema.field(idx);
923 let values = rows.iter().map(|row| row.get(name));
924 columns.push(build_typed_column(values, num_rows, field.data_type()));
925 }
926
927 let batch = RecordBatch::try_new(schema.clone(), columns)
928 .map_err(|e| datafusion::error::DataFusionError::ArrowError(Box::new(e), None))?;
929 Ok(Some(batch))
930}
931
/// Execute a procedure from the user procedure registry.
///
/// Validates argument arity and types against the registered definition,
/// then treats each supplied argument as an equality filter over the
/// procedure's backing data rows (matched by the parameter's name).
async fn execute_registered_procedure(
    graph_ctx: &GraphExecutionContext,
    procedure_name: &str,
    args: &[Value],
    yield_items: &[(String, Option<String>)],
    schema: &SchemaRef,
) -> DFResult<Option<RecordBatch>> {
    let registry = graph_ctx.procedure_registry().ok_or_else(|| {
        datafusion::error::DataFusionError::Execution(format!(
            "Procedure '{}' not supported in DataFusion engine (no procedure registry)",
            procedure_name
        ))
    })?;

    let proc_def = registry.get(procedure_name).ok_or_else(|| {
        datafusion::error::DataFusionError::Execution(format!(
            "ProcedureNotFound: Unknown procedure '{}'",
            procedure_name
        ))
    })?;

    // Arity must match exactly; there are no optional parameters here.
    if args.len() != proc_def.params.len() {
        return Err(datafusion::error::DataFusionError::Execution(format!(
            "InvalidNumberOfArguments: Procedure '{}' expects {} argument(s), got {}",
            proc_def.name,
            proc_def.params.len(),
            args.len()
        )));
    }

    // Null arguments skip the type check (they still participate in row
    // filtering below).
    for (i, (arg_val, param)) in args.iter().zip(&proc_def.params).enumerate() {
        if !arg_val.is_null() && !check_proc_type_compatible(arg_val, &param.param_type) {
            return Err(datafusion::error::DataFusionError::Execution(format!(
                "InvalidArgumentType: Argument {} ('{}') of procedure '{}' has incompatible type",
                i, param.name, proc_def.name
            )));
        }
    }

    // Keep only rows whose value for each parameter column matches the
    // corresponding argument; rows lacking the column always pass.
    let filtered: Vec<&HashMap<String, Value>> = proc_def
        .data
        .iter()
        .filter(|row| {
            for (param, arg_val) in proc_def.params.iter().zip(args) {
                if let Some(row_val) = row.get(&param.name)
                    && !proc_values_match(row_val, arg_val)
                {
                    return false;
                }
            }
            true
        })
        .collect();

    // CALL without YIELD produces an empty batch over the (empty) schema.
    if yield_items.is_empty() {
        return Ok(Some(create_empty_batch(schema.clone())?));
    }

    if filtered.is_empty() {
        return Ok(Some(create_empty_batch(schema.clone())?));
    }

    let num_rows = filtered.len();
    let mut columns: Vec<ArrayRef> = Vec::new();

    // One column per yield item, typed from the schema field at the same index.
    for (idx, (name, _alias)) in yield_items.iter().enumerate() {
        let field = schema.field(idx);
        let values = filtered.iter().map(|row| row.get(name.as_str()));
        columns.push(build_typed_column(values, num_rows, field.data_type()));
    }

    let batch = RecordBatch::try_new(schema.clone(), columns)
        .map_err(|e| datafusion::error::DataFusionError::ArrowError(Box::new(e), None))?;
    Ok(Some(batch))
}
1021
1022fn check_proc_type_compatible(
1024 val: &Value,
1025 expected: &crate::query::executor::procedure::ProcedureValueType,
1026) -> bool {
1027 use crate::query::executor::procedure::ProcedureValueType;
1028 match expected {
1029 ProcedureValueType::Any => true,
1030 ProcedureValueType::String => val.is_string(),
1031 ProcedureValueType::Boolean => val.is_bool(),
1032 ProcedureValueType::Integer => val.is_i64(),
1033 ProcedureValueType::Float => val.is_f64() || val.is_i64(),
1034 ProcedureValueType::Number => val.is_number(),
1035 }
1036}
1037
1038fn proc_values_match(row_val: &Value, arg_val: &Value) -> bool {
1040 if arg_val.is_null() || row_val.is_null() {
1041 return arg_val.is_null() && row_val.is_null();
1042 }
1043 if let (Some(a), Some(b)) = (row_val.as_f64(), arg_val.as_f64()) {
1045 return (a - b).abs() < f64::EPSILON;
1046 }
1047 row_val == arg_val
1048}
1049
/// Execute a `uni.algo.*` procedure through the algorithm registry,
/// draining its result stream into a single record batch.
async fn execute_algo_procedure(
    graph_ctx: &GraphExecutionContext,
    procedure_name: &str,
    args: &[Value],
    yield_items: &[(String, Option<String>)],
    schema: &SchemaRef,
) -> DFResult<Option<RecordBatch>> {
    use futures::StreamExt;
    use uni_algo::algo::procedures::AlgoContext;

    let registry = graph_ctx.algo_registry().ok_or_else(|| {
        datafusion::error::DataFusionError::Execution(
            "Algorithm registry not available".to_string(),
        )
    })?;

    let procedure = registry.get(procedure_name).ok_or_else(|| {
        datafusion::error::DataFusionError::Execution(format!(
            "Unknown algorithm: {}",
            procedure_name
        ))
    })?;

    let signature = procedure.signature();

    // Algorithms take JSON arguments; convert the already-evaluated values.
    let serde_args: Vec<serde_json::Value> = args.iter().cloned().map(|v| v.into()).collect();

    let algo_ctx = AlgoContext::new(graph_ctx.storage().clone(), None);

    let mut stream = procedure.execute(algo_ctx, serde_args);
    let mut rows = Vec::new();
    while let Some(row_res) = stream.next().await {
        // Re-check the query deadline on the first row and every 1000th
        // thereafter so long-running algorithms remain cancellable.
        if rows.len() % 1000 == 0 {
            graph_ctx
                .check_timeout()
                .map_err(|e| datafusion::error::DataFusionError::Execution(e.to_string()))?;
        }
        let row =
            row_res.map_err(|e| datafusion::error::DataFusionError::Execution(e.to_string()))?;
        rows.push(row);
    }

    build_algo_batch(&rows, &signature, yield_items, schema)
}
1104
1105fn json_to_value(jv: &serde_json::Value) -> Value {
1107 match jv {
1108 serde_json::Value::Null => Value::Null,
1109 serde_json::Value::Bool(b) => Value::Bool(*b),
1110 serde_json::Value::Number(n) => {
1111 if let Some(i) = n.as_i64() {
1112 Value::Int(i)
1113 } else if let Some(f) = n.as_f64() {
1114 Value::Float(f)
1115 } else {
1116 Value::Null
1117 }
1118 }
1119 serde_json::Value::String(s) => Value::String(s.clone()),
1120 other => Value::String(other.to_string()),
1121 }
1122}
1123
/// Materialize algorithm result rows into a record batch.
///
/// Each yield item is looked up by name in the procedure signature to find
/// its positional index into `AlgoResultRow::values`; a yield not declared
/// by the signature produces an all-null column.
fn build_algo_batch(
    rows: &[uni_algo::algo::procedures::AlgoResultRow],
    signature: &uni_algo::algo::procedures::ProcedureSignature,
    yield_items: &[(String, Option<String>)],
    schema: &SchemaRef,
) -> DFResult<Option<RecordBatch>> {
    if rows.is_empty() {
        return Ok(Some(create_empty_batch(schema.clone())?));
    }

    let num_rows = rows.len();
    let mut columns: Vec<ArrayRef> = Vec::new();

    for (idx, (yield_name, _alias)) in yield_items.iter().enumerate() {
        // Position of this yield among the signature's declared outputs.
        let sig_idx = signature
            .yields
            .iter()
            .position(|(n, _)| *n == yield_name.as_str());

        // Convert the JSON cell at that position for every row.
        // NOTE(review): `row.values[si]` assumes each row has one value per
        // declared yield — verify the algo stream upholds this invariant.
        let uni_values: Vec<Value> = rows
            .iter()
            .map(|row| match sig_idx {
                Some(si) => json_to_value(&row.values[si]),
                None => Value::Null,
            })
            .collect();

        let field = schema.field(idx);
        let values = uni_values.iter().map(Some);
        columns.push(build_typed_column(values, num_rows, field.data_type()));
    }

    let batch = RecordBatch::try_new(schema.clone(), columns)
        .map_err(|e| datafusion::error::DataFusionError::ArrowError(Box::new(e), None))?;
    Ok(Some(batch))
}
1162
1163fn require_string_arg(args: &[Value], index: usize, description: &str) -> DFResult<String> {
1169 args.get(index)
1170 .and_then(|v| v.as_str())
1171 .map(|s| s.to_string())
1172 .ok_or_else(|| {
1173 datafusion::error::DataFusionError::Execution(format!("{description} must be a string"))
1174 })
1175}
1176
1177fn extract_optional_filter(args: &[Value], index: usize) -> Option<String> {
1180 args.get(index).and_then(|v| {
1181 if v.is_null() {
1182 None
1183 } else {
1184 v.as_str().map(|s| s.to_string())
1185 }
1186 })
1187}
1188
1189fn extract_optional_threshold(args: &[Value], index: usize) -> Option<f64> {
1192 args.get(index)
1193 .and_then(|v| if v.is_null() { None } else { v.as_f64() })
1194}
1195
1196fn require_int_arg(args: &[Value], index: usize, description: &str) -> DFResult<usize> {
1198 args.get(index)
1199 .and_then(|v| v.as_u64())
1200 .map(|v| v as usize)
1201 .ok_or_else(|| {
1202 datafusion::error::DataFusionError::Execution(format!(
1203 "{description} must be an integer"
1204 ))
1205 })
1206}
1207
1208async fn auto_embed_text(
1217 graph_ctx: &GraphExecutionContext,
1218 label: &str,
1219 property: &str,
1220 query_text: &str,
1221) -> DFResult<Vec<f32>> {
1222 let storage = graph_ctx.storage();
1223 let uni_schema = storage.schema_manager().schema();
1224 let index_config = uni_schema.vector_index_for_property(label, property);
1225
1226 let embedding_config = index_config
1227 .and_then(|cfg| cfg.embedding_config.as_ref())
1228 .ok_or_else(|| {
1229 datafusion::error::DataFusionError::Execution(format!(
1230 "Cannot auto-embed: vector index for {label}.{property} has no embedding_config. \
1231 Either provide a pre-computed vector or create the index with embedding options."
1232 ))
1233 })?;
1234
1235 let runtime = graph_ctx.xervo_runtime().ok_or_else(|| {
1236 datafusion::error::DataFusionError::Execution(
1237 "Cannot auto-embed: Uni-Xervo runtime not configured".to_string(),
1238 )
1239 })?;
1240
1241 let embedder = runtime
1242 .embedding(&embedding_config.alias)
1243 .await
1244 .map_err(|e| datafusion::error::DataFusionError::Execution(e.to_string()))?;
1245 let embeddings = embedder
1246 .embed(vec![query_text])
1247 .await
1248 .map_err(|e| datafusion::error::DataFusionError::Execution(e.to_string()))?;
1249 embeddings.into_iter().next().ok_or_else(|| {
1250 datafusion::error::DataFusionError::Execution(
1251 "Embedding service returned no results".to_string(),
1252 )
1253 })
1254}
1255
/// Executes the `uni.vector.query` procedure: a k-nearest-neighbor vector
/// search over `label.property`.
///
/// Positional arguments: label, property, query (text to auto-embed or a
/// pre-computed vector), k, optional filter string, optional distance
/// threshold.
///
/// Returns an empty batch when no results survive the threshold; otherwise
/// delegates batch construction to `build_search_result_batch`.
async fn execute_vector_query(
    graph_ctx: &GraphExecutionContext,
    args: &[Value],
    yield_items: &[(String, Option<String>)],
    target_properties: &HashMap<String, Vec<String>>,
    schema: &SchemaRef,
) -> DFResult<Option<RecordBatch>> {
    let label = require_string_arg(args, 0, "uni.vector.query: first argument (label)")?;
    let property = require_string_arg(args, 1, "uni.vector.query: second argument (property)")?;

    let query_val = args.get(2).ok_or_else(|| {
        datafusion::error::DataFusionError::Execution(
            "uni.vector.query: third argument (query) is required".to_string(),
        )
    })?;

    let storage = graph_ctx.storage();

    // A string query is embedded via the index's configured model; anything
    // else must already be a numeric vector.
    let query_vector: Vec<f32> = if let Some(query_text) = query_val.as_str() {
        auto_embed_text(graph_ctx, &label, &property, query_text).await?
    } else {
        extract_vector(query_val)?
    };

    let k = require_int_arg(args, 3, "uni.vector.query: fourth argument (k)")?;
    let filter = extract_optional_filter(args, 4);
    let threshold = extract_optional_threshold(args, 5);
    let query_ctx = graph_ctx.query_context();

    let mut results = storage
        .vector_search(
            &label,
            &property,
            &query_vector,
            k,
            filter.as_deref(),
            Some(&query_ctx),
        )
        .await
        .map_err(|e| datafusion::error::DataFusionError::Execution(e.to_string()))?;

    // Threshold is a maximum distance: keep only hits at or under it.
    if let Some(max_dist) = threshold {
        results.retain(|(_, dist)| *dist <= max_dist as f32);
    }

    if results.is_empty() {
        return Ok(Some(create_empty_batch(schema.clone())?));
    }

    // The index's distance metric drives distance-to-score conversion;
    // default to L2 when the index config is unavailable.
    let schema_manager = storage.schema_manager();
    let uni_schema = schema_manager.schema();
    let metric = uni_schema
        .vector_index_for_property(&label, &property)
        .map(|config| config.metric.clone())
        .unwrap_or(uni_common::core::schema::DistanceMetric::L2);

    build_search_result_batch(
        &results,
        &label,
        &metric,
        yield_items,
        target_properties,
        graph_ctx,
        schema,
    )
    .await
}
1325
/// Executes the `uni.fts.query` procedure: a full-text search over
/// `label.property`.
///
/// Positional arguments: label, property, search term, k, optional filter
/// string, optional minimum-score threshold.
///
/// Returns an empty batch when no results survive the threshold; otherwise
/// delegates batch construction to `build_search_result_batch`.
async fn execute_fts_query(
    graph_ctx: &GraphExecutionContext,
    args: &[Value],
    yield_items: &[(String, Option<String>)],
    target_properties: &HashMap<String, Vec<String>>,
    schema: &SchemaRef,
) -> DFResult<Option<RecordBatch>> {
    let label = require_string_arg(args, 0, "uni.fts.query: first argument (label)")?;
    let property = require_string_arg(args, 1, "uni.fts.query: second argument (property)")?;
    let search_term = require_string_arg(args, 2, "uni.fts.query: third argument (search_term)")?;
    let k = require_int_arg(args, 3, "uni.fts.query: fourth argument (k)")?;
    let filter = extract_optional_filter(args, 4);
    let threshold = extract_optional_threshold(args, 5);

    let storage = graph_ctx.storage();
    let query_ctx = graph_ctx.query_context();

    let mut results = storage
        .fts_search(
            &label,
            &property,
            &search_term,
            k,
            filter.as_deref(),
            Some(&query_ctx),
        )
        .await
        .map_err(|e| datafusion::error::DataFusionError::Execution(e.to_string()))?;

    // Unlike the vector path, the threshold here is a MINIMUM score.
    if let Some(min_score) = threshold {
        results.retain(|(_, score)| *score as f64 >= min_score);
    }

    if results.is_empty() {
        return Ok(Some(create_empty_batch(schema.clone())?));
    }

    // NOTE(review): L2 is passed as a placeholder metric; the shared batch
    // builder will apply calculate_score's 1/(1+d) transform to FTS scores
    // for the "score" yield, which inverts their ordering relative to the
    // raw score — confirm this is intended. The "distance" yield carries
    // the raw FTS score unchanged.
    build_search_result_batch(
        &results,
        &label,
        &uni_common::core::schema::DistanceMetric::L2,
        yield_items,
        target_properties,
        graph_ctx,
        schema,
    )
    .await
}
1380
/// Executes the `uni.search` procedure: a hybrid vector + full-text search
/// whose results are fused by rank (RRF, the default) or by weighted score.
///
/// Positional arguments: label, properties (object `{vector, fts}` or a
/// single string used for both), query text, optional pre-computed query
/// vector, k, optional filter, optional options object
/// (`method`, `alpha`, `over_fetch`, `rrf_k`).
///
/// Both legs over-fetch `k * over_fetch` candidates so fusion has enough
/// overlap to rank, then the fused list is truncated to `k`.
async fn execute_hybrid_search(
    graph_ctx: &GraphExecutionContext,
    args: &[Value],
    yield_items: &[(String, Option<String>)],
    target_properties: &HashMap<String, Vec<String>>,
    schema: &SchemaRef,
) -> DFResult<Option<RecordBatch>> {
    let label = require_string_arg(args, 0, "uni.search: first argument (label)")?;

    let properties_val = args.get(1).ok_or_else(|| {
        datafusion::error::DataFusionError::Execution(
            "uni.search: second argument (properties) is required".to_string(),
        )
    })?;

    // Either leg may be absent: an object may name only one of
    // {vector, fts}; a bare string targets the same property for both.
    let (vector_prop, fts_prop) = if let Some(obj) = properties_val.as_object() {
        let vec_prop = obj
            .get("vector")
            .and_then(|v| v.as_str())
            .map(|s| s.to_string());
        let fts_prop = obj
            .get("fts")
            .and_then(|v| v.as_str())
            .map(|s| s.to_string());
        (vec_prop, fts_prop)
    } else if let Some(prop) = properties_val.as_str() {
        (Some(prop.to_string()), Some(prop.to_string()))
    } else {
        return Err(datafusion::error::DataFusionError::Execution(
            "Properties must be an object {vector: '...', fts: '...'} or a string".to_string(),
        ));
    };

    let query_text = require_string_arg(args, 2, "uni.search: third argument (query_text)")?;

    // Optional pre-computed query vector. NOTE(review): non-numeric array
    // elements are silently dropped by filter_map, which can shorten the
    // vector — confirm whether malformed input should error instead.
    let query_vector: Option<Vec<f32>> = args.get(3).and_then(|v| {
        if v.is_null() {
            return None;
        }
        v.as_array().map(|arr| {
            arr.iter()
                .filter_map(|v| v.as_f64().map(|f| f as f32))
                .collect()
        })
    });

    let k = require_int_arg(args, 4, "uni.search: fifth argument (k)")?;
    let filter = extract_optional_filter(args, 5);

    // Fusion options with defaults: RRF method, alpha 0.5, 2x over-fetch,
    // rrf_k 60 (the conventional RRF constant).
    let options_val = args.get(6);
    let options_map = options_val.and_then(|v| v.as_object());
    let fusion_method = options_map
        .and_then(|m| m.get("method"))
        .and_then(|v| v.as_str())
        .unwrap_or("rrf")
        .to_string();
    let alpha = options_map
        .and_then(|m| m.get("alpha"))
        .and_then(|v| v.as_f64())
        .unwrap_or(0.5) as f32;
    let over_fetch_factor = options_map
        .and_then(|m| m.get("over_fetch"))
        .and_then(|v| v.as_f64())
        .unwrap_or(2.0) as f32;
    let rrf_k = options_map
        .and_then(|m| m.get("rrf_k"))
        .and_then(|v| v.as_u64())
        .unwrap_or(60) as usize;

    let over_fetch_k = (k as f32 * over_fetch_factor).ceil() as usize;

    let storage = graph_ctx.storage();
    let query_ctx = graph_ctx.query_context();

    // Vector leg: uses the provided vector, else auto-embeds the query
    // text. NOTE(review): embedding failures are swallowed here
    // (unwrap_or_default), which silently skips the vector leg — confirm
    // best-effort is intended rather than surfacing the error.
    let mut vector_results: Vec<(Vid, f32)> = Vec::new();
    if let Some(ref vec_prop) = vector_prop {
        let qvec = if let Some(ref v) = query_vector {
            v.clone()
        } else {
            auto_embed_text(graph_ctx, &label, vec_prop, &query_text)
                .await
                .unwrap_or_default()
        };

        if !qvec.is_empty() {
            vector_results = storage
                .vector_search(
                    &label,
                    vec_prop,
                    &qvec,
                    over_fetch_k,
                    filter.as_deref(),
                    Some(&query_ctx),
                )
                .await
                .map_err(|e| datafusion::error::DataFusionError::Execution(e.to_string()))?;
        }
    }

    // Full-text leg.
    let mut fts_results: Vec<(Vid, f32)> = Vec::new();
    if let Some(ref fts_prop) = fts_prop {
        fts_results = storage
            .fts_search(
                &label,
                fts_prop,
                &query_text,
                over_fetch_k,
                filter.as_deref(),
                Some(&query_ctx),
            )
            .await
            .map_err(|e| datafusion::error::DataFusionError::Execution(e.to_string()))?;
    }

    // Any method other than "weighted" falls back to RRF.
    let fused_results = match fusion_method.as_str() {
        "weighted" => fuse_weighted(&vector_results, &fts_results, alpha),
        _ => fuse_rrf(&vector_results, &fts_results, rrf_k),
    };

    let final_results: Vec<_> = fused_results.into_iter().take(k).collect();

    if final_results.is_empty() {
        return Ok(Some(create_empty_batch(schema.clone())?));
    }

    // Per-vertex raw scores retained for the vector_score / fts_score /
    // distance yields.
    let vec_score_map: HashMap<Vid, f32> = vector_results.iter().cloned().collect();
    let fts_score_map: HashMap<Vid, f32> = fts_results.iter().cloned().collect();
    let fts_max = fts_results.iter().map(|(_, s)| *s).fold(0.0f32, f32::max);

    // Distance metric of the vector index (if any), defaulting to L2.
    let uni_schema = storage.schema_manager().schema();
    let metric = vector_prop
        .as_ref()
        .and_then(|vp| {
            uni_schema
                .vector_index_for_property(&label, vp)
                .map(|config| config.metric.clone())
        })
        .unwrap_or(uni_common::core::schema::DistanceMetric::L2);

    let score_ctx = HybridScoreContext {
        vec_score_map: &vec_score_map,
        fts_score_map: &fts_score_map,
        fts_max,
        metric: &metric,
    };

    build_hybrid_search_batch(
        &final_results,
        &score_ctx,
        &label,
        yield_items,
        target_properties,
        graph_ctx,
        schema,
    )
    .await
}
1554
1555fn fuse_rrf(vec_results: &[(Vid, f32)], fts_results: &[(Vid, f32)], k: usize) -> Vec<(Vid, f32)> {
1558 let mut scores: HashMap<Vid, f32> = HashMap::new();
1559
1560 for (rank, (vid, _)) in vec_results.iter().enumerate() {
1561 let rrf_score = 1.0 / (k as f32 + rank as f32 + 1.0);
1562 *scores.entry(*vid).or_default() += rrf_score;
1563 }
1564
1565 for (rank, (vid, _)) in fts_results.iter().enumerate() {
1566 let rrf_score = 1.0 / (k as f32 + rank as f32 + 1.0);
1567 *scores.entry(*vid).or_default() += rrf_score;
1568 }
1569
1570 let mut results: Vec<_> = scores.into_iter().collect();
1571 results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
1572 results
1573}
1574
1575fn fuse_weighted(
1578 vec_results: &[(Vid, f32)],
1579 fts_results: &[(Vid, f32)],
1580 alpha: f32,
1581) -> Vec<(Vid, f32)> {
1582 let vec_max = vec_results.iter().map(|(_, s)| *s).fold(f32::MIN, f32::max);
1584 let vec_min = vec_results.iter().map(|(_, s)| *s).fold(f32::MAX, f32::min);
1585 let vec_range = if vec_max > vec_min {
1586 vec_max - vec_min
1587 } else {
1588 1.0
1589 };
1590
1591 let fts_max = fts_results.iter().map(|(_, s)| *s).fold(0.0f32, f32::max);
1592
1593 let vec_scores: HashMap<Vid, f32> = vec_results
1594 .iter()
1595 .map(|(vid, dist)| {
1596 let norm = 1.0 - (dist - vec_min) / vec_range;
1597 (*vid, norm)
1598 })
1599 .collect();
1600
1601 let fts_scores: HashMap<Vid, f32> = fts_results
1602 .iter()
1603 .map(|(vid, score)| {
1604 let norm = if fts_max > 0.0 { score / fts_max } else { 0.0 };
1605 (*vid, norm)
1606 })
1607 .collect();
1608
1609 let all_vids: std::collections::HashSet<Vid> = vec_scores
1610 .keys()
1611 .chain(fts_scores.keys())
1612 .cloned()
1613 .collect();
1614
1615 let mut results: Vec<(Vid, f32)> = all_vids
1616 .into_iter()
1617 .map(|vid| {
1618 let vec_score = *vec_scores.get(&vid).unwrap_or(&0.0);
1619 let fts_score = *fts_scores.get(&vid).unwrap_or(&0.0);
1620 let fused = alpha * vec_score + (1.0 - alpha) * fts_score;
1621 (vid, fused)
1622 })
1623 .collect();
1624
1625 results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
1626 results
1627}
1628
/// Per-query score lookups shared while building a hybrid-search batch.
struct HybridScoreContext<'a> {
    /// Raw vector distances keyed by vertex id.
    vec_score_map: &'a HashMap<Vid, f32>,
    /// Raw FTS scores keyed by vertex id.
    fts_score_map: &'a HashMap<Vid, f32>,
    /// Maximum raw FTS score, used for normalization (0.0 when no FTS hits).
    fts_max: f32,
    /// Distance metric of the vector index, used to convert distances to scores.
    metric: &'a uni_common::core::schema::DistanceMetric,
}
1636
/// Builds the output `RecordBatch` for a hybrid search.
///
/// For each yield item (in order), appends the matching column(s): full
/// node columns, vid, fused score, per-leg vector/FTS scores (null when a
/// vertex was found by only one leg), raw vector distance, or a null
/// string column for unrecognized yields. Column order must match
/// `schema`, which was built from the same yield list.
async fn build_hybrid_search_batch(
    results: &[(Vid, f32)],
    scores: &HybridScoreContext<'_>,
    label: &str,
    yield_items: &[(String, Option<String>)],
    target_properties: &HashMap<String, Vec<String>>,
    graph_ctx: &GraphExecutionContext,
    schema: &SchemaRef,
) -> DFResult<Option<RecordBatch>> {
    let num_rows = results.len();
    let vids: Vec<Vid> = results.iter().map(|(vid, _)| *vid).collect();
    let fused_scores: Vec<f32> = results.iter().map(|(_, s)| *s).collect();

    let property_manager = graph_ctx.property_manager();
    let query_ctx = graph_ctx.query_context();
    let uni_schema = graph_ctx.storage().schema_manager().schema();
    let label_props = uni_schema.properties.get(label);

    let has_node_yield = yield_items
        .iter()
        .any(|(name, _)| map_yield_to_canonical(name) == "node");

    // Vertex properties are only fetched when a node yield actually needs
    // them.
    let props_map = if has_node_yield {
        property_manager
            .get_batch_vertex_props_for_label(&vids, label, Some(&query_ctx))
            .await
            .map_err(|e| datafusion::error::DataFusionError::Execution(e.to_string()))?
    } else {
        HashMap::new()
    };

    let mut columns: Vec<ArrayRef> = Vec::new();

    for (name, alias) in yield_items {
        let output_name = alias.as_ref().unwrap_or(name);
        let canonical = map_yield_to_canonical(name);

        match canonical.as_str() {
            // A node yield expands to several columns (vid, name, labels,
            // requested properties).
            "node" => {
                columns.extend(build_node_yield_columns(
                    &vids,
                    label,
                    output_name,
                    target_properties,
                    &props_map,
                    label_props,
                )?);
            }
            "vid" => {
                let mut builder = Int64Builder::with_capacity(num_rows);
                for vid in &vids {
                    builder.append_value(vid.as_u64() as i64);
                }
                columns.push(Arc::new(builder.finish()));
            }
            // The fused (RRF or weighted) score.
            "score" => {
                let mut builder = Float32Builder::with_capacity(num_rows);
                for score in &fused_scores {
                    builder.append_value(*score);
                }
                columns.push(Arc::new(builder.finish()));
            }
            // Vector-leg score: distance converted via the index metric;
            // null when the vertex was not a vector hit.
            "vector_score" => {
                let mut builder = Float32Builder::with_capacity(num_rows);
                for vid in &vids {
                    if let Some(&dist) = scores.vec_score_map.get(vid) {
                        let score = calculate_score(dist, scores.metric);
                        builder.append_value(score);
                    } else {
                        builder.append_null();
                    }
                }
                columns.push(Arc::new(builder.finish()));
            }
            // FTS-leg score normalized by the query's max FTS score; null
            // when the vertex was not an FTS hit.
            "fts_score" => {
                let mut builder = Float32Builder::with_capacity(num_rows);
                for vid in &vids {
                    if let Some(&raw_score) = scores.fts_score_map.get(vid) {
                        let norm = if scores.fts_max > 0.0 {
                            raw_score / scores.fts_max
                        } else {
                            0.0
                        };
                        builder.append_value(norm);
                    } else {
                        builder.append_null();
                    }
                }
                columns.push(Arc::new(builder.finish()));
            }
            // Raw vector distance; null when not a vector hit.
            "distance" => {
                let mut builder = Float64Builder::with_capacity(num_rows);
                for vid in &vids {
                    if let Some(&dist) = scores.vec_score_map.get(vid) {
                        builder.append_value(dist as f64);
                    } else {
                        builder.append_null();
                    }
                }
                columns.push(Arc::new(builder.finish()));
            }
            // Unknown yields become all-null string columns so the batch
            // still matches the schema.
            _ => {
                let mut builder = StringBuilder::with_capacity(num_rows, 0);
                for _ in 0..num_rows {
                    builder.append_null();
                }
                columns.push(Arc::new(builder.finish()));
            }
        }
    }

    let batch = RecordBatch::try_new(schema.clone(), columns)
        .map_err(|e| datafusion::error::DataFusionError::ArrowError(Box::new(e), None))?;
    Ok(Some(batch))
}
1755
/// Builds the output `RecordBatch` for a single-leg (vector or FTS)
/// search.
///
/// `results` carries (vid, distance) pairs; scores are derived from the
/// distances via `calculate_score` with the given metric. For each yield
/// item (in order), appends node columns, distance, score, vid, or a null
/// string column for unrecognized yields. Column order must match
/// `schema`.
async fn build_search_result_batch(
    results: &[(Vid, f32)],
    label: &str,
    metric: &uni_common::core::schema::DistanceMetric,
    yield_items: &[(String, Option<String>)],
    target_properties: &HashMap<String, Vec<String>>,
    graph_ctx: &GraphExecutionContext,
    schema: &SchemaRef,
) -> DFResult<Option<RecordBatch>> {
    let num_rows = results.len();
    let vids: Vec<Vid> = results.iter().map(|(vid, _)| *vid).collect();
    let distances: Vec<f32> = results.iter().map(|(_, d)| *d).collect();

    // Metric-dependent distance -> similarity conversion, precomputed once.
    let scores: Vec<f32> = distances
        .iter()
        .map(|dist| calculate_score(*dist, metric))
        .collect();

    let property_manager = graph_ctx.property_manager();
    let query_ctx = graph_ctx.query_context();
    let uni_schema = graph_ctx.storage().schema_manager().schema();
    let label_props = uni_schema.properties.get(label);

    let has_node_yield = yield_items
        .iter()
        .any(|(name, _)| map_yield_to_canonical(name) == "node");

    // Vertex properties are only fetched when a node yield needs them.
    let props_map = if has_node_yield {
        property_manager
            .get_batch_vertex_props_for_label(&vids, label, Some(&query_ctx))
            .await
            .map_err(|e| datafusion::error::DataFusionError::Execution(e.to_string()))?
    } else {
        HashMap::new()
    };

    let mut columns: Vec<ArrayRef> = Vec::new();

    for (name, alias) in yield_items {
        let output_name = alias.as_ref().unwrap_or(name);
        let canonical = map_yield_to_canonical(name);

        match canonical.as_str() {
            // A node yield expands to several columns (vid, name, labels,
            // requested properties).
            "node" => {
                columns.extend(build_node_yield_columns(
                    &vids,
                    label,
                    output_name,
                    target_properties,
                    &props_map,
                    label_props,
                )?);
            }
            // Raw distance (or raw FTS score, for the FTS path).
            "distance" => {
                let mut builder = Float64Builder::with_capacity(num_rows);
                for dist in &distances {
                    builder.append_value(*dist as f64);
                }
                columns.push(Arc::new(builder.finish()));
            }
            "score" => {
                let mut builder = Float32Builder::with_capacity(num_rows);
                for score in &scores {
                    builder.append_value(*score);
                }
                columns.push(Arc::new(builder.finish()));
            }
            "vid" => {
                let mut builder = Int64Builder::with_capacity(num_rows);
                for vid in &vids {
                    builder.append_value(vid.as_u64() as i64);
                }
                columns.push(Arc::new(builder.finish()));
            }
            // Unknown yields become all-null string columns so the batch
            // still matches the schema.
            _ => {
                let mut builder = StringBuilder::with_capacity(num_rows, 0);
                for _ in 0..num_rows {
                    builder.append_null();
                }
                columns.push(Arc::new(builder.finish()));
            }
        }
    }

    let batch = RecordBatch::try_new(schema.clone(), columns)
        .map_err(|e| datafusion::error::DataFusionError::ArrowError(Box::new(e), None))?;
    Ok(Some(batch))
}
1855
/// Builds the columns that represent a yielded node.
///
/// Always emits, in order: a UInt64 vid column, a string column holding
/// the vid's display string, and a list-of-string labels column (every
/// row gets the single search label). Then appends one typed column per
/// property requested for `output_name` in `target_properties`, resolving
/// each property's Arrow type from the label's schema metadata.
fn build_node_yield_columns(
    vids: &[Vid],
    label: &str,
    output_name: &str,
    target_properties: &HashMap<String, Vec<String>>,
    props_map: &HashMap<Vid, uni_common::Properties>,
    label_props: Option<&std::collections::HashMap<String, uni_common::core::schema::PropertyMeta>>,
) -> DFResult<Vec<ArrayRef>> {
    let num_rows = vids.len();
    let mut columns = Vec::new();

    // Column 1: raw vertex ids.
    let mut vid_builder = UInt64Builder::with_capacity(num_rows);
    for vid in vids {
        vid_builder.append_value(vid.as_u64());
    }
    columns.push(Arc::new(vid_builder.finish()) as ArrayRef);

    // Column 2: the vid's string form (display representation of the node).
    let mut var_builder = StringBuilder::with_capacity(num_rows, num_rows * 20);
    for vid in vids {
        var_builder.append_value(vid.to_string());
    }
    columns.push(Arc::new(var_builder.finish()) as ArrayRef);

    // Column 3: labels list; every result row carries the single searched
    // label.
    let mut labels_builder = arrow_array::builder::ListBuilder::new(StringBuilder::new());
    for _ in 0..num_rows {
        labels_builder.values().append_value(label);
        labels_builder.append(true);
    }
    columns.push(Arc::new(labels_builder.finish()) as ArrayRef);

    // Remaining columns: requested properties, typed per the label schema.
    if let Some(props) = target_properties.get(output_name) {
        for prop_name in props {
            let data_type = resolve_property_type(prop_name, label_props);
            let column = crate::query::df_graph::scan::build_property_column_static(
                vids, props_map, prop_name, &data_type,
            )?;
            columns.push(column);
        }
    }

    Ok(columns)
}
1908
1909fn extract_vector(val: &Value) -> DFResult<Vec<f32>> {
1911 match val {
1912 Value::Vector(vec) => Ok(vec.clone()),
1913 Value::List(arr) => {
1914 let mut vec = Vec::with_capacity(arr.len());
1915 for v in arr {
1916 if let Some(f) = v.as_f64() {
1917 vec.push(f as f32);
1918 } else {
1919 return Err(datafusion::error::DataFusionError::Execution(
1920 "Query vector must contain numbers".to_string(),
1921 ));
1922 }
1923 }
1924 Ok(vec)
1925 }
1926 _ => Err(datafusion::error::DataFusionError::Execution(
1927 "Query vector must be a list or vector".to_string(),
1928 )),
1929 }
1930}
1931
1932fn calculate_score(distance: f32, metric: &uni_common::core::schema::DistanceMetric) -> f32 {
1934 match metric {
1935 uni_common::core::schema::DistanceMetric::Cosine => {
1936 (2.0 - distance) / 2.0
1938 }
1939 uni_common::core::schema::DistanceMetric::Dot => distance,
1940 _ => {
1941 1.0 / (1.0 + distance)
1943 }
1944 }
1945}