Skip to main content

uni_query/procedures_plugin/
vector.rs

1// SPDX-License-Identifier: Apache-2.0
2// Copyright 2024-2026 Dragonscale Team
3
4//! `uni.vector.query` — k-nearest-neighbor over a vector index.
5
6use std::future::Future;
7use std::sync::Arc;
8use std::sync::OnceLock;
9
10use arrow_array::RecordBatch;
11use arrow_schema::{DataType, Field, Schema, SchemaRef};
12use datafusion::error::Result as DFResult;
13use datafusion::execution::SendableRecordBatchStream;
14use datafusion::logical_expr::ColumnarValue;
15use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
16use futures::stream;
17use uni_common::Value;
18use uni_plugin::traits::procedure::{
19    NamedArgType, ProcedureContext, ProcedureMode, ProcedurePlugin, ProcedureSignature,
20};
21use uni_plugin::traits::scalar::ArgType;
22use uni_plugin::{FnError, PluginError, PluginRegistrar, QName, SideEffects};
23
24use crate::procedures_plugin::host_args::{columnar_args_to_values, require_host};
25use crate::query::df_graph::search_procedures::run_vector_query;
26use crate::query::executor::procedure_host::QueryProcedureHost;
27
28// Rust guideline compliant
29
30fn signature() -> &'static ProcedureSignature {
31    static SIG: OnceLock<ProcedureSignature> = OnceLock::new();
32    SIG.get_or_init(|| ProcedureSignature {
33        args: vec![
34            NamedArgType {
35                name: smol_str::SmolStr::new("label"),
36                ty: ArgType::Primitive(DataType::Utf8),
37                default: None,
38                doc: "Vertex label to search.".to_owned(),
39            },
40            NamedArgType {
41                name: smol_str::SmolStr::new("property"),
42                ty: ArgType::Primitive(DataType::Utf8),
43                default: None,
44                doc: "Vector property name on the label.".to_owned(),
45            },
46            NamedArgType {
47                name: smol_str::SmolStr::new("query"),
48                ty: ArgType::CypherValue,
49                default: None,
50                doc: "Query vector (List<Float>) or query text (String, auto-embedded).".to_owned(),
51            },
52            NamedArgType {
53                name: smol_str::SmolStr::new("k"),
54                ty: ArgType::Primitive(DataType::Int64),
55                default: None,
56                doc: "Number of nearest neighbours to return.".to_owned(),
57            },
58            NamedArgType {
59                name: smol_str::SmolStr::new("filter"),
60                ty: ArgType::Primitive(DataType::Utf8),
61                default: None,
62                doc: "Optional pushdown filter expression.".to_owned(),
63            },
64            NamedArgType {
65                name: smol_str::SmolStr::new("threshold"),
66                ty: ArgType::Primitive(DataType::Float64),
67                default: None,
68                doc: "Optional maximum distance threshold (post-filter).".to_owned(),
69            },
70            NamedArgType {
71                name: smol_str::SmolStr::new("options"),
72                ty: ArgType::CypherValue,
73                default: None,
74                doc: "Optional reranker / extra options map.".to_owned(),
75            },
76        ],
77        yields: vector_query_yields(),
78        mode: ProcedureMode::Read,
79        side_effects: SideEffects::ReadOnly,
80        retry_contract: None,
81        batch_input: None,
82        docs:
83            "Approximate-nearest-neighbour over a vector index with optional cross-encoder rerank."
84                .to_owned(),
85    })
86}
87
88/// Yield columns produced by `uni.vector.query`.
89fn vector_query_yields() -> Vec<Field> {
90    vec![
91        vid_field(),
92        Field::new("distance", DataType::Float64, true),
93        Field::new("score", DataType::Float32, true),
94        Field::new("rerank_score", DataType::Float32, true),
95    ]
96}
97
98/// Yield columns produced by `uni.fts.query` (no `distance` — BM25 has
99/// no distance metric).
100pub(super) fn fts_query_yields() -> Vec<Field> {
101    vec![
102        vid_field(),
103        Field::new("score", DataType::Float32, true),
104        Field::new("rerank_score", DataType::Float32, true),
105    ]
106}
107
108/// Yield columns produced by `uni.search` (hybrid — emits the full
109/// fused-score family).
110pub(super) fn hybrid_search_yields() -> Vec<Field> {
111    vec![
112        vid_field(),
113        Field::new("score", DataType::Float32, true),
114        Field::new("rerank_score", DataType::Float32, true),
115        Field::new("vector_score", DataType::Float32, true),
116        Field::new("fts_score", DataType::Float32, true),
117        Field::new("distance", DataType::Float64, true),
118    ]
119}
120
121/// Build the canonical `vid` field for search-style procedures and tag
122/// it with `_yield_kind = node_vid_source` so the planner's schema
123/// builder knows this procedure supports node-shaped YIELD expansion
124/// (`YIELD node` / `YIELD foo` projecting `<name>._vid + <name> +
125/// <name>._labels + <name>.<prop>` columns). The tag is the seam that
126/// replaced the procedure-name match arm in `procedure_call::build_schema`.
127fn vid_field() -> Field {
128    let mut md = std::collections::HashMap::new();
129    md.insert("_yield_kind".to_owned(), "node_vid_source".to_owned());
130    Field::new("vid", DataType::Int64, true).with_metadata(md)
131}
132
133#[derive(Debug)]
134struct VectorQueryProc;
135
136impl ProcedurePlugin for VectorQueryProc {
137    fn signature(&self) -> &ProcedureSignature {
138        signature()
139    }
140
141    fn invoke(
142        &self,
143        ctx: ProcedureContext<'_>,
144        args: &[ColumnarValue],
145    ) -> Result<SendableRecordBatchStream, FnError> {
146        run_search_procedure(
147            "uni.vector.query",
148            &ctx,
149            args,
150            signature(),
151            |host, uni_args, yield_items, output_schema| async move {
152                let target_properties = host.target_properties().clone();
153                run_vector_query(
154                    &host,
155                    &uni_args,
156                    &yield_items,
157                    &target_properties,
158                    &output_schema,
159                )
160                .await
161            },
162        )
163    }
164}
165
166/// Pick the right `(yield_items, output_schema)` for a search-plugin
167/// invocation. When the host carries planner state (composite query),
168/// honour it so node-shape yields expand correctly; otherwise fall back
169/// to the plugin's `signature.yields` (standalone CALL with no
170/// surrounding query plan, e.g. unit-test paths).
171pub(super) fn resolve_yields_and_schema(
172    host: &crate::query::executor::procedure_host::QueryProcedureHost,
173    sig: &ProcedureSignature,
174    fallback_schema: &Arc<Schema>,
175) -> (Vec<(String, Option<String>)>, Arc<Schema>) {
176    let host_yields = host.yield_items();
177    if host_yields.is_empty() {
178        let yield_items: Vec<(String, Option<String>)> = sig
179            .yields
180            .iter()
181            .map(|f| (f.name().clone(), None))
182            .collect();
183        (yield_items, fallback_schema.clone())
184    } else {
185        let output_schema = host
186            .expected_schema()
187            .cloned()
188            .unwrap_or_else(|| fallback_schema.clone());
189        (host_yields.to_vec(), output_schema)
190    }
191}
192
193/// Shared `ProcedurePlugin::invoke` body for the three host-coupled
194/// search procedures (`uni.vector.query`, `uni.fts.query`, `uni.search`).
195///
196/// They differ only in their procedure name, signature, and the `run_*`
197/// helper that produces the result batch; everything else (host
198/// down-cast, arg decode, yield/schema resolution, single-batch
199/// streaming) is identical.
200pub(super) fn run_search_procedure<F, Fut>(
201    proc_name: &'static str,
202    ctx: &ProcedureContext<'_>,
203    args: &[ColumnarValue],
204    sig: &'static ProcedureSignature,
205    run_fn: F,
206) -> Result<SendableRecordBatchStream, FnError>
207where
208    F: FnOnce(QueryProcedureHost, Vec<Value>, Vec<(String, Option<String>)>, SchemaRef) -> Fut
209        + Send
210        + 'static,
211    Fut: Future<Output = DFResult<Option<RecordBatch>>> + Send + 'static,
212{
213    let host = require_host(ctx, proc_name)?.clone();
214    let uni_args = columnar_args_to_values(args);
215    let fallback_schema = Arc::new(Schema::new(sig.yields.clone()));
216    let (yield_items, output_schema) = resolve_yields_and_schema(&host, sig, &fallback_schema);
217
218    let stream_schema = output_schema.clone();
219    let stream = stream::once(async move {
220        let batch = run_fn(host, uni_args, yield_items, output_schema.clone())
221            .await?
222            .unwrap_or_else(|| RecordBatch::new_empty(output_schema.clone()));
223        Ok::<_, datafusion::error::DataFusionError>(batch)
224    });
225    Ok(Box::pin(RecordBatchStreamAdapter::new(
226        stream_schema,
227        stream,
228    )))
229}
230
231/// Register `uni.vector.query` into `r`.
232///
233/// # Errors
234///
235/// Returns [`PluginError::DuplicateRegistration`] if a qname is taken.
236pub fn register_into(r: &mut PluginRegistrar<'_>) -> Result<(), PluginError> {
237    r.procedure(
238        QName::new("uni", "vector.query"),
239        signature().clone(),
240        Arc::new(VectorQueryProc),
241    )?;
242    Ok(())
243}