Skip to main content

uni_query/query/df_graph/
common.rs

1// SPDX-License-Identifier: Apache-2.0
2// Copyright 2024-2026 Dragonscale Team
3
4//! Common helpers shared across graph execution plan implementations.
5//!
6//! This module provides shared utilities to reduce code duplication across
7//! the df_graph module's execution plan implementations.
8
9use arrow_array::{ArrayRef, RecordBatch};
10use arrow_schema::{DataType, Field, Schema, SchemaRef};
11use datafusion::arrow::array::Array;
12use datafusion::common::Result as DFResult;
13use datafusion::physical_expr::{EquivalenceProperties, Partitioning};
14use datafusion::physical_plan::PlanProperties;
15use datafusion::prelude::SessionContext;
16use futures::TryStreamExt;
17use parking_lot::RwLock;
18use std::collections::HashMap;
19use std::sync::Arc;
20use uni_common::Value;
21use uni_common::core::schema::{DistanceMetric, Schema as UniSchema};
22use uni_cypher::ast::{BinaryOp, CypherLiteral, Expr};
23use uni_store::storage::manager::StorageManager;
24
25use super::GraphExecutionContext;
26use super::procedure_call::map_yield_to_canonical;
27use super::unwind::arrow_to_json_value;
28use crate::query::df_planner::HybridPhysicalPlanner;
29use crate::query::planner::LogicalPlan;
30
31/// Convert an `ArrowError` into a `DataFusionError`.
32///
33/// Wraps the Arrow-level error into DataFusion's `ArrowError` variant. Use this
34/// with `.map_err(arrow_err)` when calling raw Arrow compute kernels from
35/// DataFusion execution plans.
36pub fn arrow_err(e: arrow::error::ArrowError) -> datafusion::error::DataFusionError {
37    datafusion::error::DataFusionError::ArrowError(Box::new(e), None)
38}
39
40/// Compute standard plan properties for graph operators.
41///
42/// All graph operators use the same plan properties:
43/// - Unknown partitioning with 1 partition
44/// - Incremental emission type
45/// - Bounded execution
46pub fn compute_plan_properties(schema: SchemaRef) -> PlanProperties {
47    PlanProperties::new(
48        EquivalenceProperties::new(schema),
49        Partitioning::UnknownPartitioning(1),
50        datafusion::physical_plan::execution_plan::EmissionType::Incremental,
51        datafusion::physical_plan::execution_plan::Boundedness::Bounded,
52    )
53}
54
55/// Return the Arrow `DataType` for `_labels` columns: `List<Utf8>`.
56///
57/// This is used across scan, traverse, bind, and other modules whenever a
58/// `_labels` field needs to be declared in a schema. Centralizing the
59/// definition avoids divergence and reduces boilerplate.
60pub fn labels_data_type() -> DataType {
61    DataType::List(Arc::new(Field::new("item", DataType::Utf8, true)))
62}
63
/// Extract a `UInt64Array` of vertex/edge IDs from an Arrow column.
///
/// Accepts several encodings of VIDs, tried in order:
/// - `UInt64`: the native VID type, returned as a zero-copy borrow.
/// - `Int64`: from parameter injection where `arrow_to_json_value`
///   round-trips through `Value::Int`; values are cast to `UInt64`.
/// - `Struct` containing a `_vid` field: entity-struct aliases; recurses on
///   the `_vid` child column.
/// - `LargeBinary`: CypherValue-encoded Node values (or JSON round-tripped
///   node Maps); decoded element by element.
/// - `Null`: all-null columns from OPTIONAL MATCH; yields an all-null array.
///
/// # Errors
///
/// Returns a `DataFusionError::Execution` if the column type is none of the
/// encodings listed above.
pub fn column_as_vid_array(
    col: &dyn arrow_array::Array,
) -> datafusion::error::Result<std::borrow::Cow<'_, arrow_array::UInt64Array>> {
    use arrow_array::{Int64Array, StructArray, UInt64Array};
    use arrow_schema::DataType;

    // Fast path: already the native VID type — return a zero-copy borrow.
    if let Some(arr) = col.as_any().downcast_ref::<UInt64Array>() {
        return Ok(std::borrow::Cow::Borrowed(arr));
    }

    if let Some(arr) = col.as_any().downcast_ref::<Int64Array>() {
        // NOTE(review): negative values wrap via `as u64`; assumed not to
        // occur for real VIDs — confirm upstream guarantees.
        let cast: UInt64Array = arr.iter().map(|v| v.map(|i| i as u64)).collect();
        return Ok(std::borrow::Cow::Owned(cast));
    }

    // Support entity-struct aliases (e.g., WITH coalesce(b, c) AS x) where
    // traversal inputs may provide the source as a Struct with an "_vid" field.
    if let Some(arr) = col.as_any().downcast_ref::<StructArray>()
        && let DataType::Struct(fields) = arr.data_type()
        && let Some((vid_idx, _)) = fields.find("_vid")
    {
        // Recurse: the `_vid` child may itself be UInt64, Int64, etc.
        return column_as_vid_array(arr.column(vid_idx).as_ref());
    }

    // Support CypherValue-encoded Node values in LargeBinary columns
    // (e.g., from list comprehension loop variables over node collections)
    // Also handles JSON round-tripped nodes (Value::Map with _id field)
    if let Some(arr) = col.as_any().downcast_ref::<arrow_array::LargeBinaryArray>() {
        let vids = vids_from_large_binary(arr);
        return Ok(std::borrow::Cow::Owned(vids));
    }

    // OPTIONAL MATCH can produce all-null columns with Arrow Null type
    if *col.data_type() == DataType::Null {
        let vids: UInt64Array = (0..col.len()).map(|_| None::<u64>).collect();
        return Ok(std::borrow::Cow::Owned(vids));
    }

    Err(datafusion::error::DataFusionError::Execution(format!(
        "VID column has type {:?}, expected UInt64 or Int64",
        col.data_type()
    )))
}
117
118/// Extract a VID from a CypherValue.
119///
120/// Handles both `Value::Node` (native node) and `Value::Map` with `_id` field
121/// (JSON round-tripped node from `cv_array_to_large_list`).
122fn extract_vid_from_value(val: &Value) -> Option<u64> {
123    match val {
124        Value::Node(node) => Some(node.vid.as_u64()),
125        Value::Map(map) => {
126            // Handle round-tripped nodes that became Maps.
127            // Path nodes use struct fields (_vid, _label, properties) which
128            // round-trip through arrow_to_json_value as { "_vid": Int(N), ... }.
129            // Value::Node → serde_json uses { "_id": "N", ... }.
130            // Check both keys to handle either path.
131
132            // Check _vid first (from path struct → arrow_to_json_value round-trip)
133            if let Some(Value::Int(vid)) = map.get("_vid") {
134                return Some(*vid as u64);
135            }
136            // Also check _id (from Value::Node → serde_json round-trip)
137            if let Some(Value::String(id_str)) = map.get("_id") {
138                return id_str
139                    .strip_prefix("Vid(")
140                    .and_then(|s| s.strip_suffix(')'))
141                    .unwrap_or(id_str)
142                    .parse::<u64>()
143                    .ok();
144            }
145            if let Some(Value::Int(id)) = map.get("_id") {
146                return Some(*id as u64);
147            }
148            None
149        }
150        _ => None,
151    }
152}
153
154/// Extract VIDs from a `LargeBinaryArray` of CypherValue-encoded values.
155///
156/// Decodes each element and delegates to [`extract_vid_from_value`].
157/// Null elements and decode failures produce null VID entries.
158fn vids_from_large_binary(arr: &arrow_array::LargeBinaryArray) -> arrow_array::UInt64Array {
159    use uni_common::cypher_value_codec;
160
161    (0..arr.len())
162        .map(|i| {
163            if arr.is_null(i) {
164                return None;
165            }
166            cypher_value_codec::decode(arr.value(i))
167                .ok()
168                .as_ref()
169                .and_then(extract_vid_from_value)
170        })
171        .collect()
172}
173
174/// Extract VIDs from a column of CypherValue-encoded Node values.
175///
176/// Takes a `LargeBinary` array where each element is a CypherValue-encoded
177/// value and extracts VIDs from Node values. Non-Node values produce nulls.
178/// Also handles JSON round-tripped node Maps from `cv_array_to_large_list`.
179pub fn extract_vids_from_cypher_value_column(col: &dyn Array) -> DFResult<arrow_array::ArrayRef> {
180    let binary_col = col
181        .as_any()
182        .downcast_ref::<arrow_array::LargeBinaryArray>()
183        .ok_or_else(|| {
184            datafusion::error::DataFusionError::Execution(
185                "extract_vids_from_cypher_value_column: expected LargeBinary column".to_string(),
186            )
187        })?;
188    Ok(Arc::new(vids_from_large_binary(binary_col)) as arrow_array::ArrayRef)
189}
190
191/// Extract a typed value from a column at a given row index.
192///
193/// Looks up `col_name` in the batch schema, downcasts to `T`, and applies
194/// `extract_fn` if the value is valid. Returns `None` if the column is missing,
195/// the downcast fails, or the value is null.
196pub(crate) fn extract_column_value<T: arrow_array::Array + 'static, R>(
197    batch: &RecordBatch,
198    col_name: &str,
199    row_idx: usize,
200    extract_fn: impl FnOnce(&T, usize) -> R,
201) -> Option<R> {
202    let (idx, _) = batch.schema().column_with_name(col_name)?;
203    let col = batch.column(idx);
204    let arr = col.as_any().downcast_ref::<T>()?;
205    if arr.is_valid(row_idx) {
206        Some(extract_fn(arr, row_idx))
207    } else {
208        None
209    }
210}
211
212/// Build the standard node struct fields for path structures.
213///
214/// Used when materializing path objects containing nodes.
215/// Fields: `_vid`, `_labels`, `properties`
216pub fn node_struct_fields() -> arrow_schema::Fields {
217    arrow_schema::Fields::from(vec![
218        Field::new("_vid", DataType::UInt64, false),
219        Field::new("_labels", labels_data_type(), true),
220        Field::new("properties", DataType::LargeBinary, true),
221    ])
222}
223
224/// Build the standard edge struct fields for path structures.
225///
226/// Used when materializing path objects containing edges.
227/// Fields: `_eid`, `_type_name`, `_src`, `_dst`, `properties`
228pub fn edge_struct_fields() -> arrow_schema::Fields {
229    arrow_schema::Fields::from(vec![
230        Field::new("_eid", DataType::UInt64, false),
231        Field::new("_type_name", DataType::Utf8, false),
232        Field::new("_src", DataType::UInt64, false),
233        Field::new("_dst", DataType::UInt64, false),
234        Field::new("properties", DataType::LargeBinary, true),
235    ])
236}
237
238/// Encode a properties HashMap to CypherValue bytes for LargeBinary columns.
239///
240/// Used when materializing path properties that need to be stored in LargeBinary
241/// columns. Converts the HashMap into a `Value::Map` and encodes it using the
242/// CypherValue codec.
243pub fn encode_props_to_cv(props: &std::collections::HashMap<String, uni_common::Value>) -> Vec<u8> {
244    let val = uni_common::Value::Map(props.clone());
245    uni_common::cypher_value_codec::encode(&val)
246}
247
248/// Build edge list field for schema with given step variable name.
249///
250/// Creates a list of edge structs for the relationship variable in VLP patterns.
251/// For example, `r` in `MATCH (a)-[r*1..3]->(b)` gets a `List<EdgeStruct>`.
252pub fn build_edge_list_field(step_var: &str) -> Field {
253    let edge_item = Field::new("item", DataType::Struct(edge_struct_fields()), true);
254    // Field must be nullable to support OPTIONAL MATCH unmatched (r = NULL)
255    Field::new(step_var, DataType::List(Arc::new(edge_item)), true)
256}
257
258/// Build path struct field for schema with given path variable name.
259///
260/// Creates a struct field with `nodes` and `relationships` lists.
261pub fn build_path_struct_field(path_var: &str) -> Field {
262    let node_item = Field::new("item", DataType::Struct(node_struct_fields()), true);
263    let nodes_field = Field::new("nodes", DataType::List(Arc::new(node_item)), true);
264
265    let edge_item = Field::new("item", DataType::Struct(edge_struct_fields()), true);
266    let relationships_field =
267        Field::new("relationships", DataType::List(Arc::new(edge_item)), true);
268
269    Field::new(
270        path_var,
271        DataType::Struct(arrow_schema::Fields::from(vec![
272            nodes_field,
273            relationships_field,
274        ])),
275        true,
276    )
277}
278
279/// Extend an input schema with a path struct field.
280///
281/// Clones the fields from `input_schema` and appends a path struct field
282/// using [`build_path_struct_field`].
283pub fn extend_schema_with_path(input_schema: SchemaRef, path_variable: &str) -> SchemaRef {
284    let mut fields: Vec<Arc<Field>> = input_schema.fields().to_vec();
285    fields.push(Arc::new(build_path_struct_field(path_variable)));
286    Arc::new(Schema::new(fields))
287}
288
289/// Build a path struct array from nodes and relationships list arrays.
290///
291/// Combines the nodes and relationships arrays into a single `StructArray` with
292/// the standard path structure (`nodes`, `relationships`), applying the given
293/// validity mask.
294pub fn build_path_struct_array(
295    nodes_array: ArrayRef,
296    rels_array: ArrayRef,
297    path_validity: Vec<bool>,
298) -> DFResult<arrow_array::StructArray> {
299    Ok(arrow_array::StructArray::try_new(
300        arrow_schema::Fields::from(vec![
301            Arc::new(Field::new("nodes", nodes_array.data_type().clone(), true)),
302            Arc::new(Field::new(
303                "relationships",
304                rels_array.data_type().clone(),
305                true,
306            )),
307        ]),
308        vec![nodes_array, rels_array],
309        Some(arrow::buffer::NullBuffer::from(path_validity)),
310    )?)
311}
312
313/// Create a `ListBuilder<StructBuilder>` for building edge list arrays.
314///
315/// Used when materializing edge lists for step variables (`r` in `[r*1..3]`)
316/// and path relationship arrays. Returns a builder whose struct fields match
317/// `edge_struct_fields()`.
318pub fn new_edge_list_builder()
319-> arrow_array::builder::ListBuilder<arrow_array::builder::StructBuilder> {
320    use arrow_array::builder::{LargeBinaryBuilder, StringBuilder, StructBuilder, UInt64Builder};
321    arrow_array::builder::ListBuilder::new(StructBuilder::new(
322        edge_struct_fields(),
323        vec![
324            Box::new(UInt64Builder::new()),
325            Box::new(StringBuilder::new()),
326            Box::new(UInt64Builder::new()),
327            Box::new(UInt64Builder::new()),
328            Box::new(LargeBinaryBuilder::new()),
329        ],
330    ))
331}
332
333/// Create a `ListBuilder<StructBuilder>` for building node list arrays.
334///
335/// Used when materializing path node arrays. Returns a builder whose struct
336/// fields match `node_struct_fields()`.
337pub fn new_node_list_builder()
338-> arrow_array::builder::ListBuilder<arrow_array::builder::StructBuilder> {
339    use arrow_array::builder::{
340        LargeBinaryBuilder, ListBuilder, StringBuilder, StructBuilder, UInt64Builder,
341    };
342    arrow_array::builder::ListBuilder::new(StructBuilder::new(
343        node_struct_fields(),
344        vec![
345            Box::new(UInt64Builder::new()),
346            Box::new(ListBuilder::new(StringBuilder::new())),
347            Box::new(LargeBinaryBuilder::new()),
348        ],
349    ))
350}
351
/// Append a single edge to an edge struct builder.
///
/// Writes `_eid` (field 0), `_type_name` (field 1), `_src` (field 2),
/// `_dst` (field 3), and `properties` (field 4), then appends the struct
/// row with valid status. The field indices must stay in sync with
/// [`edge_struct_fields`]. The `query_ctx` is used to look up edge
/// properties from the L0 visibility chain.
pub fn append_edge_to_struct(
    struct_builder: &mut arrow_array::builder::StructBuilder,
    eid: uni_common::core::id::Eid,
    type_name: &str,
    src_vid: u64,
    dst_vid: u64,
    query_ctx: &uni_store::runtime::context::QueryContext,
) {
    use arrow_array::builder::{LargeBinaryBuilder, StringBuilder, UInt64Builder};
    use uni_store::runtime::l0_visibility;

    // Field 0: _eid
    struct_builder
        .field_builder::<UInt64Builder>(0)
        .unwrap()
        .append_value(eid.as_u64());
    // Field 1: _type_name
    struct_builder
        .field_builder::<StringBuilder>(1)
        .unwrap()
        .append_value(type_name);
    // Field 2: _src
    struct_builder
        .field_builder::<UInt64Builder>(2)
        .unwrap()
        .append_value(src_vid);
    // Field 3: _dst
    struct_builder
        .field_builder::<UInt64Builder>(3)
        .unwrap()
        .append_value(dst_vid);
    // Field 4: properties — CypherValue-encoded bytes, or null when the edge
    // has no properties visible in the L0 chain.
    let props_builder = struct_builder
        .field_builder::<LargeBinaryBuilder>(4)
        .unwrap();
    if let Some(props) = l0_visibility::get_edge_properties(eid, query_ctx) {
        let cv_bytes = encode_props_to_cv(&props);
        props_builder.append_value(&cv_bytes);
    } else {
        props_builder.append_null();
    }
    // Mark the struct row itself as valid.
    struct_builder.append(true);
}
395
396/// Append a null edge struct row (placeholder values + null validity).
397///
398/// Arrow struct builders require all field builders to advance even for null rows.
399/// This appends default placeholder values and marks the struct row as null.
400fn append_null_edge_struct(struct_builder: &mut arrow_array::builder::StructBuilder) {
401    use arrow_array::builder::{LargeBinaryBuilder, StringBuilder, UInt64Builder};
402
403    struct_builder
404        .field_builder::<UInt64Builder>(0)
405        .unwrap()
406        .append_value(0);
407    struct_builder
408        .field_builder::<StringBuilder>(1)
409        .unwrap()
410        .append_value("");
411    struct_builder
412        .field_builder::<UInt64Builder>(2)
413        .unwrap()
414        .append_value(0);
415    struct_builder
416        .field_builder::<UInt64Builder>(3)
417        .unwrap()
418        .append_value(0);
419    struct_builder
420        .field_builder::<LargeBinaryBuilder>(4)
421        .unwrap()
422        .append_null();
423    struct_builder.append(false);
424}
425
426/// Append an edge to a struct builder, handling the `Option<Eid>` case.
427///
428/// When `eid` is `Some`, resolves the type name from `batch_type_name` (primary)
429/// or L0 visibility (fallback), then delegates to [`append_edge_to_struct`].
430/// When `eid` is `None`, appends a null struct row.
431pub fn append_edge_to_struct_optional(
432    struct_builder: &mut arrow_array::builder::StructBuilder,
433    eid: Option<uni_common::core::id::Eid>,
434    src_vid: u64,
435    dst_vid: u64,
436    batch_type_name: Option<String>,
437    query_ctx: &uni_store::runtime::context::QueryContext,
438) {
439    match eid {
440        Some(e) => {
441            use uni_store::runtime::l0_visibility;
442            let type_name = batch_type_name
443                .or_else(|| l0_visibility::get_edge_type(e, query_ctx))
444                .unwrap_or_default();
445            append_edge_to_struct(struct_builder, e, &type_name, src_vid, dst_vid, query_ctx);
446        }
447        None => append_null_edge_struct(struct_builder),
448    }
449}
450
/// Append a single node to a node struct builder.
///
/// Writes `_vid` (field 0), `_labels` (field 1), and `properties` (field 2),
/// then appends the struct row with valid status. The field indices must
/// stay in sync with [`node_struct_fields`]. The `query_ctx` is used to look
/// up labels and properties from the L0 visibility chain.
pub fn append_node_to_struct(
    struct_builder: &mut arrow_array::builder::StructBuilder,
    vid: uni_common::core::id::Vid,
    query_ctx: &uni_store::runtime::context::QueryContext,
) {
    use arrow_array::builder::{LargeBinaryBuilder, ListBuilder, StringBuilder, UInt64Builder};
    use uni_store::runtime::l0_visibility;

    // Field 0: _vid
    struct_builder
        .field_builder::<UInt64Builder>(0)
        .unwrap()
        .append_value(vid.as_u64());
    // Field 1: _labels — one list entry containing every label visible for
    // this vertex in the L0 chain (a valid empty list when there are none).
    let labels = l0_visibility::get_vertex_labels(vid, query_ctx);
    let labels_builder = struct_builder
        .field_builder::<ListBuilder<StringBuilder>>(1)
        .unwrap();
    let values = labels_builder.values();
    for lbl in &labels {
        values.append_value(lbl);
    }
    labels_builder.append(true);
    // Field 2: properties — CypherValue-encoded bytes, or null when the
    // vertex has no properties visible in the L0 chain.
    let props_builder = struct_builder
        .field_builder::<LargeBinaryBuilder>(2)
        .unwrap();
    if let Some(props) = l0_visibility::get_vertex_properties(vid, query_ctx) {
        let cv_bytes = encode_props_to_cv(&props);
        props_builder.append_value(&cv_bytes);
    } else {
        props_builder.append_null();
    }
    // Mark the struct row itself as valid.
    struct_builder.append(true);
}
488
489/// Append a null node struct row (placeholder values + null validity).
490///
491/// Arrow struct builders require all field builders to advance even for null rows.
492/// This appends default placeholder values and marks the struct row as null.
493fn append_null_node_struct(struct_builder: &mut arrow_array::builder::StructBuilder) {
494    use arrow_array::builder::{LargeBinaryBuilder, ListBuilder, StringBuilder, UInt64Builder};
495
496    struct_builder
497        .field_builder::<UInt64Builder>(0)
498        .unwrap()
499        .append_value(0);
500    struct_builder
501        .field_builder::<ListBuilder<StringBuilder>>(1)
502        .unwrap()
503        .append(true);
504    struct_builder
505        .field_builder::<LargeBinaryBuilder>(2)
506        .unwrap()
507        .append_null();
508    struct_builder.append(false);
509}
510
511/// Append a node to a struct builder, handling the `Option<Vid>` case.
512///
513/// When `vid` is `Some`, delegates to [`append_node_to_struct`].
514/// When `vid` is `None`, appends a null struct row.
515pub fn append_node_to_struct_optional(
516    struct_builder: &mut arrow_array::builder::StructBuilder,
517    vid: Option<uni_common::core::id::Vid>,
518    query_ctx: &uni_store::runtime::context::QueryContext,
519) {
520    match vid {
521        Some(v) => append_node_to_struct(struct_builder, v, query_ctx),
522        None => append_null_node_struct(struct_builder),
523    }
524}
525
526/// Re-encode a `LargeListArray` of CypherValue elements into a `LargeBinaryArray` of CypherValue arrays.
527///
528/// Each row in the input `LargeListArray` contains zero or more `LargeBinary`
529/// elements that are individually CypherValue-encoded values. This function decodes
530/// each element, wraps them into a `serde_json::Value::Array`, and re-encodes
531/// the whole array as a single CypherValue blob in the output `LargeBinaryArray`.
532///
533/// Null rows in the input produce null entries in the output.
534///
535/// # Errors
536///
537/// Returns a `DataFusionError::Execution` if the input is not a
538/// `LargeListArray` or if CypherValue decoding fails.
539pub fn large_list_of_cv_to_cv_array(
540    list: &datafusion::arrow::array::LargeListArray,
541) -> datafusion::error::Result<Arc<dyn datafusion::arrow::array::Array>> {
542    use datafusion::arrow::array::{LargeBinaryArray, LargeBinaryBuilder};
543
544    let values = list.values();
545    let binary_values = values
546        .as_any()
547        .downcast_ref::<LargeBinaryArray>()
548        .ok_or_else(|| {
549            datafusion::error::DataFusionError::Execution(
550                "large_list_of_cv_to_cv_array: inner values must be LargeBinaryArray".to_string(),
551            )
552        })?;
553
554    let mut builder = LargeBinaryBuilder::new();
555
556    for row_idx in 0..list.len() {
557        if list.is_null(row_idx) {
558            builder.append_null();
559            continue;
560        }
561
562        let start = list.offsets()[row_idx] as usize;
563        let end = list.offsets()[row_idx + 1] as usize;
564
565        let mut json_elements = Vec::with_capacity(end - start);
566        for elem_idx in start..end {
567            if binary_values.is_null(elem_idx) {
568                json_elements.push(serde_json::Value::Null);
569            } else {
570                let blob = binary_values.value(elem_idx);
571                match uni_common::cypher_value_codec::decode(blob) {
572                    Ok(uni_val) => {
573                        let json_val: serde_json::Value = uni_val.into();
574                        json_elements.push(json_val);
575                    }
576                    Err(_) => json_elements.push(serde_json::Value::Null),
577                }
578            }
579        }
580
581        let uni_val: uni_common::Value = serde_json::Value::Array(json_elements).into();
582        let bytes = uni_common::cypher_value_codec::encode(&uni_val);
583        builder.append_value(&bytes);
584    }
585
586    Ok(Arc::new(builder.finish()))
587}
588
589/// Convert a single Arrow array element at `idx` to `serde_json::Value`.
590///
591/// Handles the common scalar types (UInt64, Int64, Float64, Utf8, Boolean, LargeBinary).
592/// Returns `serde_json::Value::Null` for null values or unsupported types.
593fn arrow_element_to_json(
594    col: &dyn datafusion::arrow::array::Array,
595    idx: usize,
596) -> serde_json::Value {
597    use datafusion::arrow::array::{
598        BooleanArray, Float64Array, Int64Array, StringArray, UInt64Array,
599    };
600
601    if col.is_null(idx) {
602        return serde_json::Value::Null;
603    }
604
605    if let Some(arr) = col.as_any().downcast_ref::<UInt64Array>() {
606        serde_json::Value::Number(serde_json::Number::from(arr.value(idx)))
607    } else if let Some(arr) = col.as_any().downcast_ref::<Int64Array>() {
608        serde_json::Value::Number(serde_json::Number::from(arr.value(idx)))
609    } else if let Some(arr) = col.as_any().downcast_ref::<Float64Array>() {
610        serde_json::Number::from_f64(arr.value(idx))
611            .map(serde_json::Value::Number)
612            .unwrap_or(serde_json::Value::Null)
613    } else if let Some(arr) = col.as_any().downcast_ref::<StringArray>() {
614        serde_json::Value::String(arr.value(idx).to_string())
615    } else if let Some(arr) = col.as_any().downcast_ref::<BooleanArray>() {
616        serde_json::Value::Bool(arr.value(idx))
617    } else if let Some(arr) = col.as_any().downcast_ref::<arrow_array::LargeBinaryArray>() {
618        uni_common::cypher_value_codec::decode(arr.value(idx))
619            .map(|v| v.into())
620            .unwrap_or(serde_json::Value::Null)
621    } else {
622        serde_json::Value::Null
623    }
624}
625
/// Convert a typed `LargeListArray` to a `LargeBinaryArray` of CypherValue arrays.
///
/// Each row in the input `LargeListArray` contains zero or more elements of a
/// specific type (Int64, Float64, Utf8, Boolean, or nested LargeBinary). This
/// function converts each row into a JSON array and encodes it as a CypherValue blob.
///
/// If the inner type is already `LargeBinary` (CypherValue), delegates to
/// `large_list_of_cv_to_cv_array()`.
///
/// Null rows in the input produce null entries in the output.
///
/// # Errors
///
/// Returns a `DataFusionError::Execution` if the element type is unsupported
/// or the Struct downcast fails.
pub fn typed_large_list_to_cv_array(
    list: &datafusion::arrow::array::LargeListArray,
) -> datafusion::error::Result<Arc<dyn datafusion::arrow::array::Array>> {
    use datafusion::arrow::array::{LargeBinaryBuilder, StructArray};

    let values = list.values();

    // If inner type is LargeBinary, delegate to existing function
    if values.data_type() == &DataType::LargeBinary {
        return large_list_of_cv_to_cv_array(list);
    }

    // Build the element-to-JSON converter closure. For Struct arrays, we need
    // to iterate over fields; for scalar arrays, use arrow_element_to_json directly.
    let elem_to_json: Box<dyn Fn(usize) -> serde_json::Value> = match values.data_type() {
        DataType::UInt64
        | DataType::Int64
        | DataType::Float64
        | DataType::Utf8
        | DataType::Boolean => {
            // Scalar case: one JSON value per flat child index.
            let values = values.clone();
            Box::new(move |idx| arrow_element_to_json(values.as_ref(), idx))
        }
        DataType::Struct(_) => {
            let typed = values
                .as_any()
                .downcast_ref::<StructArray>()
                .ok_or_else(|| {
                    datafusion::error::DataFusionError::Execution(
                        "Expected StructArray".to_string(),
                    )
                })?;
            // Capture owned copies of the fields, child columns, and null
            // buffer so the closure does not borrow from `typed`.
            let fields: Vec<_> = typed.fields().iter().cloned().collect();
            let columns: Vec<_> = (0..typed.num_columns())
                .map(|i| typed.column(i).clone())
                .collect();
            let nulls = typed.nulls().cloned();
            Box::new(move |idx| {
                // A null struct element becomes a JSON null, not an object.
                if nulls.as_ref().is_some_and(|n| n.is_null(idx)) {
                    return serde_json::Value::Null;
                }
                // Struct element → JSON object: one key per struct field.
                let mut map = serde_json::Map::new();
                for (field_idx, field) in fields.iter().enumerate() {
                    let value = arrow_element_to_json(columns[field_idx].as_ref(), idx);
                    map.insert(field.name().clone(), value);
                }
                serde_json::Value::Object(map)
            })
        }
        other => {
            return Err(datafusion::error::DataFusionError::Execution(format!(
                "Unsupported element type for typed_large_list_to_cv_array: {:?}",
                other
            )));
        }
    };

    let mut builder = LargeBinaryBuilder::new();

    for row_idx in 0..list.len() {
        if list.is_null(row_idx) {
            builder.append_null();
            continue;
        }

        // Offsets delimit this row's slice of the flat child array.
        let start = list.offsets()[row_idx] as usize;
        let end = list.offsets()[row_idx + 1] as usize;
        let json_elements: Vec<serde_json::Value> = (start..end).map(&elem_to_json).collect();

        let uni_val: uni_common::Value = serde_json::Value::Array(json_elements).into();
        let bytes = uni_common::cypher_value_codec::encode(&uni_val);
        builder.append_value(&bytes);
    }

    Ok(Arc::new(builder.finish()))
}
716
717/// Convert a `LargeBinaryArray` of CypherValue-encoded arrays into a `LargeListArray`.
718///
719/// Each element in the input array is a CypherValue blob encoding a JSON array (e.g. `[1,2,3]`).
720/// Elements are converted to the specified `element_type`. For example, if `element_type`
721/// is `Int64`, CypherValue numbers are parsed as i64 values.
722///
723/// Non-array CypherValue values and nulls produce empty lists.
pub fn cv_array_to_large_list(
    array: &dyn datafusion::arrow::array::Array,
    element_type: &DataType,
) -> datafusion::error::Result<Arc<dyn datafusion::arrow::array::Array>> {
    use datafusion::arrow::array::LargeBinaryArray;
    use datafusion::arrow::buffer::{OffsetBuffer, ScalarBuffer};

    // CypherValue columns are physically LargeBinary blobs; any other array
    // type here is a caller bug.
    let binary_arr = array
        .as_any()
        .downcast_ref::<LargeBinaryArray>()
        .ok_or_else(|| {
            datafusion::error::DataFusionError::Execution(
                "cv_array_to_large_list: expected LargeBinaryArray".to_string(),
            )
        })?;

    // Collect all JSON elements across all rows.
    //
    // Per-row validity semantics (`nulls[i]` is the Arrow validity flag, so
    // `false` means the output list is NULL):
    //   - input NULL         -> NULL output list
    //   - undecodable blob   -> NULL output list (decode errors are swallowed,
    //                           not raised)
    //   - decoded non-array  -> valid but EMPTY list
    //   - decoded array      -> valid list with its elements
    let num_rows = binary_arr.len();
    let mut all_elements: Vec<Vec<serde_json::Value>> = Vec::with_capacity(num_rows);
    let mut nulls = Vec::with_capacity(num_rows);

    for i in 0..num_rows {
        if binary_arr.is_null(i) {
            all_elements.push(Vec::new());
            nulls.push(false);
            continue;
        }

        let blob = binary_arr.value(i);
        let uni_val = match uni_common::cypher_value_codec::decode(blob) {
            Ok(v) => v,
            Err(_) => {
                all_elements.push(Vec::new());
                nulls.push(false);
                continue;
            }
        };
        // Go through serde_json so element extraction below is uniform
        // regardless of the decoded CypherValue variant.
        let json_val_decoded: serde_json::Value = uni_val.into();

        match json_val_decoded {
            serde_json::Value::Array(elements) => {
                all_elements.push(elements);
                nulls.push(true);
            }
            _ => {
                all_elements.push(Vec::new());
                nulls.push(true);
            }
        }
    }

    // Build typed values array and offsets. Every branch below pushes exactly
    // one offset per row (after appending that row's elements), so `offsets`
    // ends with num_rows + 1 entries as LargeListArray requires.
    let mut offsets: Vec<i64> = Vec::with_capacity(num_rows + 1);
    offsets.push(0);

    let values_array: Arc<dyn datafusion::arrow::array::Array> = match element_type {
        DataType::Int64 => {
            let mut builder = datafusion::arrow::array::builder::Int64Builder::new();
            for elems in &all_elements {
                for elem in elems {
                    if let serde_json::Value::Number(n) = elem {
                        if let Some(i) = n.as_i64() {
                            builder.append_value(i);
                        } else if let Some(f) = n.as_f64() {
                            // NOTE: float elements are truncated toward zero,
                            // not rounded.
                            builder.append_value(f as i64);
                        } else {
                            builder.append_null();
                        }
                    } else {
                        // Non-numeric elements become NULL rather than erroring.
                        builder.append_null();
                    }
                }
                offsets.push(offsets.last().unwrap() + elems.len() as i64);
            }
            Arc::new(builder.finish())
        }
        DataType::Float64 => {
            let mut builder = datafusion::arrow::array::builder::Float64Builder::new();
            for elems in &all_elements {
                for elem in elems {
                    if let serde_json::Value::Number(n) = elem
                        && let Some(f) = n.as_f64()
                    {
                        builder.append_value(f);
                    } else {
                        builder.append_null();
                    }
                }
                offsets.push(offsets.last().unwrap() + elems.len() as i64);
            }
            Arc::new(builder.finish())
        }
        DataType::Utf8 | DataType::LargeUtf8 => {
            let mut builder = datafusion::arrow::array::builder::StringBuilder::new();
            for elems in &all_elements {
                for elem in elems {
                    match elem {
                        serde_json::Value::String(s) => builder.append_value(s),
                        serde_json::Value::Null => builder.append_null(),
                        // Non-string scalars are stringified via their JSON
                        // representation (e.g. `1`, `true`).
                        other => builder.append_value(other.to_string()),
                    }
                }
                offsets.push(offsets.last().unwrap() + elems.len() as i64);
            }
            Arc::new(builder.finish())
        }
        DataType::Boolean => {
            let mut builder = datafusion::arrow::array::builder::BooleanBuilder::new();
            for elems in &all_elements {
                for elem in elems {
                    if let serde_json::Value::Bool(b) = elem {
                        builder.append_value(*b);
                    } else {
                        builder.append_null();
                    }
                }
                offsets.push(offsets.last().unwrap() + elems.len() as i64);
            }
            Arc::new(builder.finish())
        }
        // Fallback: keep as CypherValue LargeBinary blobs
        _ => {
            let mut builder = datafusion::arrow::array::builder::LargeBinaryBuilder::new();
            for elems in &all_elements {
                for elem in elems {
                    // Round-trip each element back through the CypherValue
                    // codec so downstream consumers see the usual blob format.
                    let uni_val: uni_common::Value = elem.clone().into();
                    let bytes = uni_common::cypher_value_codec::encode(&uni_val);
                    builder.append_value(&bytes);
                }
                offsets.push(offsets.last().unwrap() + elems.len() as i64);
            }
            Arc::new(builder.finish())
        }
    };

    // Child field is named "item" and marked nullable (branches above may
    // append nulls for unconvertible elements).
    let field = Arc::new(Field::new("item", element_type.clone(), true));
    let offset_buffer = OffsetBuffer::new(ScalarBuffer::from(offsets));
    let null_buffer = datafusion::arrow::buffer::NullBuffer::from(nulls);

    let large_list = datafusion::arrow::array::LargeListArray::new(
        field,
        offset_buffer,
        values_array,
        Some(null_buffer),
    );

    Ok(Arc::new(large_list))
}
872
873/// Collect all record batches from all partitions of an execution plan.
874///
875/// Iterates over each partition, executes it, and collects all resulting
876/// batches into a single `Vec`. Shared by `execute_subplan` and `run_apply`.
877pub async fn collect_all_partitions(
878    plan: &Arc<dyn datafusion::physical_plan::ExecutionPlan>,
879    task_ctx: Arc<datafusion::execution::TaskContext>,
880) -> DFResult<Vec<RecordBatch>> {
881    let partition_count = plan.properties().output_partitioning().partition_count();
882
883    let mut all_batches = Vec::new();
884    for partition in 0..partition_count {
885        let stream = plan.execute(partition, task_ctx.clone())?;
886        let batches: Vec<RecordBatch> = stream.try_collect().await?;
887        all_batches.extend(batches);
888    }
889    Ok(all_batches)
890}
891
892/// Execute a logical plan using a fresh HybridPhysicalPlanner with the given params.
893///
894/// Shared by `RecursiveCTEExec`, `GraphApplyExec`, and `ExistsExecExpr`.
895pub async fn execute_subplan(
896    plan: &LogicalPlan,
897    params: &HashMap<String, Value>,
898    outer_values: &HashMap<String, Value>,
899    graph_ctx: &Arc<GraphExecutionContext>,
900    session_ctx: &Arc<RwLock<SessionContext>>,
901    storage: &Arc<StorageManager>,
902    schema_info: &Arc<UniSchema>,
903) -> DFResult<Vec<RecordBatch>> {
904    let mut planner = HybridPhysicalPlanner::with_l0_context(
905        session_ctx.clone(),
906        storage.clone(),
907        graph_ctx.l0_context().clone(),
908        graph_ctx.property_manager().clone(),
909        schema_info.clone(),
910        params.clone(),
911        outer_values.clone(),
912    );
913
914    // Propagate registries from parent context so procedures remain available
915    // inside correlated subqueries (Apply operator).
916    if let Some(registry) = graph_ctx.algo_registry() {
917        planner = planner.with_algo_registry(registry.clone());
918    }
919    if let Some(registry) = graph_ctx.procedure_registry() {
920        planner = planner.with_procedure_registry(registry.clone());
921    }
922    if let Some(runtime) = graph_ctx.xervo_runtime() {
923        planner = planner.with_xervo_runtime(runtime.clone());
924    }
925
926    let execution_plan = planner.plan(plan).map_err(|e| {
927        datafusion::error::DataFusionError::Execution(format!("Sub-plan error: {e}"))
928    })?;
929
930    let task_ctx = session_ctx.read().task_ctx();
931    let all_batches = collect_all_partitions(&execution_plan, task_ctx).await?;
932
933    Ok(all_batches)
934}
935
936/// Extract a single row from a RecordBatch as a HashMap of column name → Value.
937///
938/// Used to build parameters for correlated subqueries (Apply, EXISTS).
939pub fn extract_row_params(batch: &RecordBatch, row_idx: usize) -> HashMap<String, Value> {
940    let schema = batch.schema();
941    (0..batch.num_columns())
942        .map(|col_idx| {
943            let col_name = schema.field(col_idx).name().clone();
944            let val = arrow_to_json_value(batch.column(col_idx).as_ref(), row_idx);
945            (col_name, val)
946        })
947        .collect()
948}
949
950/// Infer the output schema of a ProcedureCall logical plan node.
951///
952/// This is a simplified version of `GraphProcedureCallExec::build_schema()` that
953/// doesn't require target_properties or graph_ctx. It covers common procedure types
954/// with basic scalar type inference. For unknown procedures or complex node expansions,
955/// it falls back to Utf8.
956fn infer_procedure_call_schema(
957    procedure_name: &str,
958    yield_items: &[(String, Option<String>)],
959    _schema_info: &UniSchema,
960) -> SchemaRef {
961    let infer_type = |name: &str| -> DataType {
962        match procedure_name {
963            "uni.schema.labels" => match name {
964                "propertyCount" | "nodeCount" | "indexCount" => DataType::Int64,
965                _ => DataType::Utf8,
966            },
967            "uni.schema.edgeTypes" | "uni.schema.relationshipTypes" => match name {
968                "propertyCount" => DataType::Int64,
969                _ => DataType::Utf8,
970            },
971            "uni.schema.constraints" => match name {
972                "enabled" => DataType::Boolean,
973                _ => DataType::Utf8,
974            },
975            "uni.schema.labelInfo" => match name {
976                "nullable" | "indexed" | "unique" => DataType::Boolean,
977                _ => DataType::Utf8,
978            },
979            "uni.vector.query" | "uni.fts.query" | "uni.search" => {
980                // Search procedures: infer types via canonical yield mapping.
981                // Node expansion happens at execution time in GraphProcedureCallExec.
982                match map_yield_to_canonical(name).as_str() {
983                    "distance" => DataType::Float64,
984                    "score" | "vector_score" | "fts_score" | "raw_score" => DataType::Float32,
985                    "vid" => DataType::Int64,
986                    _ => DataType::Utf8,
987                }
988            }
989            // uni.schema.indexes, unknown procedures, and fallback: all Utf8
990            _ => DataType::Utf8,
991        }
992    };
993
994    let fields: Vec<Field> = yield_items
995        .iter()
996        .map(|(name, alias)| {
997            let col_name = alias.as_ref().unwrap_or(name);
998            Field::new(col_name, infer_type(name), true)
999        })
1000        .collect();
1001
1002    Arc::new(Schema::new(fields))
1003}
1004
1005/// Infer the output schema of a logical plan using UniSchema property metadata.
1006///
1007/// This is needed because correlated subqueries reference outer variables that
1008/// don't exist as physical columns at planning time, so we can't dry-run plan
1009/// the subquery to get its schema. Instead we walk the logical plan and use
1010/// `UniSchema` property metadata to infer types.
1011pub fn infer_logical_plan_schema(plan: &LogicalPlan, schema_info: &UniSchema) -> SchemaRef {
1012    // Walk to outermost Project
1013    if let LogicalPlan::Project { projections, .. } = plan {
1014        let fields: Vec<Field> = projections
1015            .iter()
1016            .map(|(expr, alias)| {
1017                let name = alias.clone().unwrap_or_else(|| expr.to_string_repr());
1018                let dt = infer_expr_type(expr, schema_info);
1019                Field::new(name, dt, true)
1020            })
1021            .collect();
1022        return Arc::new(Schema::new(fields));
1023    }
1024
1025    // For non-Project plans, walk through wrapping nodes
1026    match plan {
1027        LogicalPlan::Sort { input, .. }
1028        | LogicalPlan::Limit { input, .. }
1029        | LogicalPlan::Filter { input, .. }
1030        | LogicalPlan::Distinct { input } => infer_logical_plan_schema(input, schema_info),
1031
1032        LogicalPlan::ProcedureCall {
1033            procedure_name,
1034            yield_items,
1035            ..
1036        } => infer_procedure_call_schema(procedure_name, yield_items, schema_info),
1037
1038        _ => {
1039            // Fallback: empty schema
1040            Arc::new(Schema::empty())
1041        }
1042    }
1043}
1044
1045/// Infer Arrow DataType for a Cypher expression using schema metadata.
1046fn infer_expr_type(expr: &Expr, schema_info: &UniSchema) -> DataType {
1047    match expr {
1048        Expr::Property(base, key) => {
1049            if let Expr::Variable(_) = base.as_ref() {
1050                // Look up key across all labels/edge types in schema
1051                for props in schema_info.properties.values() {
1052                    if let Some(meta) = props.get(key.as_str()) {
1053                        return meta.r#type.to_arrow();
1054                    }
1055                }
1056                DataType::LargeBinary
1057            } else {
1058                DataType::LargeBinary
1059            }
1060        }
1061        Expr::BinaryOp { left, op, right } => match op {
1062            BinaryOp::Add | BinaryOp::Sub | BinaryOp::Mul | BinaryOp::Div | BinaryOp::Mod => {
1063                let lt = infer_expr_type(left, schema_info);
1064                let rt = infer_expr_type(right, schema_info);
1065                numeric_promotion(&lt, &rt)
1066            }
1067            BinaryOp::Eq
1068            | BinaryOp::NotEq
1069            | BinaryOp::Lt
1070            | BinaryOp::LtEq
1071            | BinaryOp::Gt
1072            | BinaryOp::GtEq
1073            | BinaryOp::And
1074            | BinaryOp::Or => DataType::Boolean,
1075            _ => DataType::LargeBinary,
1076        },
1077        Expr::Literal(lit) => match lit {
1078            CypherLiteral::Integer(_) => DataType::Int64,
1079            CypherLiteral::Float(_) => DataType::Float64,
1080            CypherLiteral::String(_) => DataType::Utf8,
1081            CypherLiteral::Bool(_) => DataType::Boolean,
1082            CypherLiteral::Null => DataType::Null,
1083            CypherLiteral::Bytes(_) => DataType::LargeBinary,
1084        },
1085        Expr::Variable(_) => DataType::LargeBinary,
1086        Expr::FunctionCall { name, args, .. } => match name.to_lowercase().as_str() {
1087            "count" => DataType::Int64,
1088            "sum" | "avg" => {
1089                if let Some(arg) = args.first() {
1090                    let arg_type = infer_expr_type(arg, schema_info);
1091                    if matches!(arg_type, DataType::Float32 | DataType::Float64) {
1092                        DataType::Float64
1093                    } else {
1094                        DataType::Int64
1095                    }
1096                } else {
1097                    DataType::Int64
1098                }
1099            }
1100            "min" | "max" => {
1101                if let Some(arg) = args.first() {
1102                    infer_expr_type(arg, schema_info)
1103                } else {
1104                    DataType::LargeBinary
1105                }
1106            }
1107            "tostring" | "trim" | "ltrim" | "rtrim" | "tolower" | "toupper" | "left" | "right"
1108            | "substring" | "replace" | "reverse" | "type" => DataType::Utf8,
1109            "tointeger" | "toint" | "size" | "length" | "id" => DataType::Int64,
1110            "tofloat" => DataType::Float64,
1111            "toboolean" => DataType::Boolean,
1112            _ => DataType::LargeBinary,
1113        },
1114        _ => DataType::LargeBinary,
1115    }
1116}
1117
1118/// Numeric type promotion for binary arithmetic.
1119fn numeric_promotion(left: &DataType, right: &DataType) -> DataType {
1120    match (left, right) {
1121        (DataType::Float64, _) | (_, DataType::Float64) => DataType::Float64,
1122        (DataType::Float32, _) | (_, DataType::Float32) => DataType::Float64,
1123        (DataType::Int64, _) | (_, DataType::Int64) => DataType::Int64,
1124        (DataType::Int32, _) | (_, DataType::Int32) => DataType::Int64,
1125        _ => DataType::Int64,
1126    }
1127}
1128
1129/// Evaluate a simple expression to get a `uni_common::Value`.
1130///
1131/// Supports:
1132/// - Literal values
1133/// - Parameter references ($param)
1134/// - Variable references (node/edge variables from MATCH)
1135/// - Literal lists
1136/// - Literal maps ({key: value, ...})
1137pub(crate) fn evaluate_simple_expr(
1138    expr: &Expr,
1139    params: &HashMap<String, Value>,
1140    outer_values: &HashMap<String, Value>,
1141) -> DFResult<Value> {
1142    match expr {
1143        Expr::Literal(lit) => Ok(lit.to_value()),
1144
1145        Expr::Parameter(name) => params.get(name).cloned().ok_or_else(|| {
1146            datafusion::error::DataFusionError::Execution(format!("Parameter '{}' not found", name))
1147        }),
1148
1149        Expr::Variable(name) => {
1150            // Node variables are stored as "{name}._vid" in outer_values
1151            let vid_key = format!("{}._vid", name);
1152            if let Some(val) = outer_values.get(&vid_key) {
1153                return Ok(val.clone());
1154            }
1155            // Fall back to plain name (edge variables, scalar columns)
1156            outer_values.get(name).cloned().ok_or_else(|| {
1157                datafusion::error::DataFusionError::Execution(format!(
1158                    "Variable '{}' not found in scope (looked for '{}' and '{}')",
1159                    name, vid_key, name
1160                ))
1161            })
1162        }
1163
1164        Expr::List(items) => {
1165            let values: Vec<Value> = items
1166                .iter()
1167                .map(|item| evaluate_simple_expr(item, params, outer_values))
1168                .collect::<DFResult<_>>()?;
1169            Ok(Value::List(values))
1170        }
1171
1172        Expr::Map(entries) => {
1173            let map: HashMap<String, Value> = entries
1174                .iter()
1175                .map(|(k, v)| {
1176                    evaluate_simple_expr(v, params, outer_values).map(|val| (k.clone(), val))
1177                })
1178                .collect::<DFResult<_>>()?;
1179            Ok(Value::Map(map))
1180        }
1181
1182        _ => Err(datafusion::error::DataFusionError::Execution(format!(
1183            "Unsupported expression type for procedure argument: {:?}",
1184            expr
1185        ))),
1186    }
1187}
1188
1189/// Merge edge property metadata across multiple edge types.
1190///
1191/// When a traversal spans several edge types, property columns must accommodate
1192/// all of them. This function collects property metadata from each type and
1193/// resolves conflicts: if two types define the same property with different
1194/// data types, the merged type widens to `CypherValue`. Nullability is merged
1195/// with OR (if either is nullable, the result is nullable).
1196pub fn merged_edge_schema_props(
1197    uni_schema: &UniSchema,
1198    edge_type_ids: &[u32],
1199) -> HashMap<String, uni_common::core::schema::PropertyMeta> {
1200    let mut merged: HashMap<String, uni_common::core::schema::PropertyMeta> = HashMap::new();
1201    let mut sorted_ids = edge_type_ids.to_vec();
1202    sorted_ids.sort_unstable();
1203
1204    for edge_type_id in sorted_ids {
1205        if let Some(edge_type_name) = uni_schema.edge_type_name_by_id_unified(edge_type_id)
1206            && let Some(props) = uni_schema.properties.get(edge_type_name.as_str())
1207        {
1208            for (prop_name, meta) in props {
1209                match merged.get_mut(prop_name) {
1210                    Some(existing) => {
1211                        if existing.r#type != meta.r#type {
1212                            existing.r#type = uni_common::core::schema::DataType::CypherValue;
1213                        }
1214                        existing.nullable |= meta.nullable;
1215                    }
1216                    None => {
1217                        merged.insert(prop_name.clone(), meta.clone());
1218                    }
1219                }
1220            }
1221        }
1222    }
1223
1224    merged
1225}
1226
1227// ---------------------------------------------------------------------------
1228// Shared key extraction for Locy operators (Priority, Fold, BestBy, Fixpoint)
1229// ---------------------------------------------------------------------------
1230
/// A hashable scalar key extracted from an Arrow array row.
///
/// Used across Locy operators for grouping and deduplication.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub(crate) enum ScalarKey {
    /// Row value was NULL (derived `PartialEq` makes all NULLs group together).
    Null,
    /// Boolean column value.
    Bool(bool),
    /// Int64 column value; also carries Float64 values as their raw bit
    /// pattern (see `extract_scalar_key`). Keys from the same column share
    /// one type, so the two uses never collide.
    Int64(i64),
    /// Utf8/LargeUtf8 value, or the display-formatted fallback used for
    /// unsupported column types.
    Utf8(String),
    /// LargeBinary value (e.g. encoded CypherValue blobs).
    Binary(Vec<u8>),
}
1242
1243/// Extract a composite key from a row of a `RecordBatch`.
1244///
1245/// For each column index in `key_indices`, reads the scalar value at `row_idx`
1246/// and converts it to a `ScalarKey`. Float64 values are hashed by their bit
1247/// representation for exact grouping.
1248pub(crate) fn extract_scalar_key(
1249    batch: &RecordBatch,
1250    key_indices: &[usize],
1251    row_idx: usize,
1252) -> Vec<ScalarKey> {
1253    use arrow::array::Array;
1254    key_indices
1255        .iter()
1256        .map(|&col_idx| {
1257            let col = batch.column(col_idx);
1258            if col.is_null(row_idx) {
1259                return ScalarKey::Null;
1260            }
1261            match col.data_type() {
1262                arrow_schema::DataType::Boolean => {
1263                    let arr = col
1264                        .as_any()
1265                        .downcast_ref::<arrow_array::BooleanArray>()
1266                        .unwrap();
1267                    ScalarKey::Bool(arr.value(row_idx))
1268                }
1269                arrow_schema::DataType::Int64 => {
1270                    let arr = col
1271                        .as_any()
1272                        .downcast_ref::<arrow_array::Int64Array>()
1273                        .unwrap();
1274                    ScalarKey::Int64(arr.value(row_idx))
1275                }
1276                arrow_schema::DataType::Utf8 => {
1277                    let arr = col
1278                        .as_any()
1279                        .downcast_ref::<arrow_array::StringArray>()
1280                        .unwrap();
1281                    ScalarKey::Utf8(arr.value(row_idx).to_string())
1282                }
1283                arrow_schema::DataType::LargeBinary => {
1284                    let arr = col
1285                        .as_any()
1286                        .downcast_ref::<arrow_array::LargeBinaryArray>()
1287                        .unwrap();
1288                    ScalarKey::Binary(arr.value(row_idx).to_vec())
1289                }
1290                arrow_schema::DataType::Float64 => {
1291                    // Hash f64 as bits for grouping
1292                    let arr = col
1293                        .as_any()
1294                        .downcast_ref::<arrow_array::Float64Array>()
1295                        .unwrap();
1296                    ScalarKey::Int64(arr.value(row_idx).to_bits() as i64)
1297                }
1298                arrow_schema::DataType::LargeUtf8 => {
1299                    let arr = col
1300                        .as_any()
1301                        .downcast_ref::<arrow_array::LargeStringArray>()
1302                        .unwrap();
1303                    ScalarKey::Utf8(arr.value(row_idx).to_string())
1304                }
1305                _ => {
1306                    // Fallback (including Struct): use arrow display formatter
1307                    let formatter = arrow::util::display::ArrayFormatter::try_new(
1308                        col.as_ref(),
1309                        &arrow::util::display::FormatOptions::default(),
1310                    );
1311                    match formatter {
1312                        Ok(f) => ScalarKey::Utf8(f.value(row_idx).to_string()),
1313                        Err(_) => ScalarKey::Utf8(format!("opaque@{row_idx}")),
1314                    }
1315                }
1316            }
1317        })
1318        .collect()
1319}
1320
1321/// Convert a raw distance value into a normalised similarity score.
1322///
1323/// The conversion depends on the distance metric:
1324/// - **Cosine**: `(2 - d) / 2` (LanceDB cosine distance ranges 0..2)
1325/// - **Dot**: pass-through (already a similarity measure)
1326/// - **L2** and others: `1 / (1 + d)`
1327pub fn calculate_score(distance: f32, metric: &DistanceMetric) -> f32 {
1328    match metric {
1329        DistanceMetric::Cosine => (2.0 - distance) / 2.0,
1330        DistanceMetric::Dot => distance,
1331        _ => 1.0 / (1.0 + distance),
1332    }
1333}
1334
#[cfg(test)]
mod tests {
    use super::*;
    use arrow_array::{LargeBinaryArray, UInt64Array};
    use arrow_schema::Schema;

    /// UInt64 columns come back as signed `Value::Int`; the test name
    /// suggests this conversion is lossy for very large u64 values.
    #[test]
    fn test_extract_row_params_loses_uint64_to_int() {
        let vid_field = Field::new("n._vid", DataType::UInt64, true);
        let schema = Arc::new(Schema::new(vec![vid_field]));
        let column: ArrayRef = Arc::new(UInt64Array::from(vec![Some(7)]));
        let batch = RecordBatch::try_new(schema, vec![column]).expect("batch should be valid");

        let row = extract_row_params(&batch, 0);
        assert_eq!(row.get("n._vid"), Some(&Value::Int(7)));
    }

    /// A LargeBinary column holding an encoded CypherValue map decodes back
    /// into the original `Value::Map`.
    #[test]
    fn test_extract_row_params_decodes_largebinary_to_map() {
        let empty_map = Value::Map(HashMap::new());
        let encoded = uni_common::cypher_value_codec::encode(&empty_map);

        let props_field = Field::new("m._all_props", DataType::LargeBinary, true);
        let schema = Arc::new(Schema::new(vec![props_field]));
        let column: ArrayRef = Arc::new(LargeBinaryArray::from(vec![Some(encoded.as_slice())]));
        let batch = RecordBatch::try_new(schema, vec![column]).expect("batch should be valid");

        let row = extract_row_params(&batch, 0);
        assert_eq!(row.get("m._all_props"), Some(&empty_map));
    }
}
1377}