uni_query/procedures_plugin/
vector.rs1use std::future::Future;
7use std::sync::Arc;
8use std::sync::OnceLock;
9
10use arrow_array::RecordBatch;
11use arrow_schema::{DataType, Field, Schema, SchemaRef};
12use datafusion::error::Result as DFResult;
13use datafusion::execution::SendableRecordBatchStream;
14use datafusion::logical_expr::ColumnarValue;
15use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
16use futures::stream;
17use uni_common::Value;
18use uni_plugin::traits::procedure::{
19 NamedArgType, ProcedureContext, ProcedureMode, ProcedurePlugin, ProcedureSignature,
20};
21use uni_plugin::traits::scalar::ArgType;
22use uni_plugin::{FnError, PluginError, PluginRegistrar, QName, SideEffects};
23
24use crate::procedures_plugin::host_args::{columnar_args_to_values, require_host};
25use crate::query::df_graph::search_procedures::run_vector_query;
26use crate::query::executor::procedure_host::QueryProcedureHost;
27
28fn signature() -> &'static ProcedureSignature {
31 static SIG: OnceLock<ProcedureSignature> = OnceLock::new();
32 SIG.get_or_init(|| ProcedureSignature {
33 args: vec![
34 NamedArgType {
35 name: smol_str::SmolStr::new("label"),
36 ty: ArgType::Primitive(DataType::Utf8),
37 default: None,
38 doc: "Vertex label to search.".to_owned(),
39 },
40 NamedArgType {
41 name: smol_str::SmolStr::new("property"),
42 ty: ArgType::Primitive(DataType::Utf8),
43 default: None,
44 doc: "Vector property name on the label.".to_owned(),
45 },
46 NamedArgType {
47 name: smol_str::SmolStr::new("query"),
48 ty: ArgType::CypherValue,
49 default: None,
50 doc: "Query vector (List<Float>) or query text (String, auto-embedded).".to_owned(),
51 },
52 NamedArgType {
53 name: smol_str::SmolStr::new("k"),
54 ty: ArgType::Primitive(DataType::Int64),
55 default: None,
56 doc: "Number of nearest neighbours to return.".to_owned(),
57 },
58 NamedArgType {
59 name: smol_str::SmolStr::new("filter"),
60 ty: ArgType::Primitive(DataType::Utf8),
61 default: None,
62 doc: "Optional pushdown filter expression.".to_owned(),
63 },
64 NamedArgType {
65 name: smol_str::SmolStr::new("threshold"),
66 ty: ArgType::Primitive(DataType::Float64),
67 default: None,
68 doc: "Optional maximum distance threshold (post-filter).".to_owned(),
69 },
70 NamedArgType {
71 name: smol_str::SmolStr::new("options"),
72 ty: ArgType::CypherValue,
73 default: None,
74 doc: "Optional reranker / extra options map.".to_owned(),
75 },
76 ],
77 yields: vector_query_yields(),
78 mode: ProcedureMode::Read,
79 side_effects: SideEffects::ReadOnly,
80 retry_contract: None,
81 batch_input: None,
82 docs:
83 "Approximate-nearest-neighbour over a vector index with optional cross-encoder rerank."
84 .to_owned(),
85 })
86}
87
88fn vector_query_yields() -> Vec<Field> {
90 vec![
91 vid_field(),
92 Field::new("distance", DataType::Float64, true),
93 Field::new("score", DataType::Float32, true),
94 Field::new("rerank_score", DataType::Float32, true),
95 ]
96}
97
98pub(super) fn fts_query_yields() -> Vec<Field> {
101 vec![
102 vid_field(),
103 Field::new("score", DataType::Float32, true),
104 Field::new("rerank_score", DataType::Float32, true),
105 ]
106}
107
108pub(super) fn hybrid_search_yields() -> Vec<Field> {
111 vec![
112 vid_field(),
113 Field::new("score", DataType::Float32, true),
114 Field::new("rerank_score", DataType::Float32, true),
115 Field::new("vector_score", DataType::Float32, true),
116 Field::new("fts_score", DataType::Float32, true),
117 Field::new("sparse_score", DataType::Float32, true),
118 Field::new("distance", DataType::Float64, true),
119 ]
120}
121
122fn vid_field() -> Field {
129 let mut md = std::collections::HashMap::new();
130 md.insert("_yield_kind".to_owned(), "node_vid_source".to_owned());
131 Field::new("vid", DataType::Int64, true).with_metadata(md)
132}
133
134#[derive(Debug)]
135struct VectorQueryProc;
136
137impl ProcedurePlugin for VectorQueryProc {
138 fn signature(&self) -> &ProcedureSignature {
139 signature()
140 }
141
142 fn invoke(
143 &self,
144 ctx: ProcedureContext<'_>,
145 args: &[ColumnarValue],
146 ) -> Result<SendableRecordBatchStream, FnError> {
147 run_search_procedure(
148 "uni.vector.query",
149 &ctx,
150 args,
151 signature(),
152 |host, uni_args, yield_items, output_schema| async move {
153 let target_properties = host.target_properties().clone();
154 run_vector_query(
155 &host,
156 &uni_args,
157 &yield_items,
158 &target_properties,
159 &output_schema,
160 )
161 .await
162 },
163 )
164 }
165}
166
167pub(super) fn resolve_yields_and_schema(
173 host: &crate::query::executor::procedure_host::QueryProcedureHost,
174 sig: &ProcedureSignature,
175 fallback_schema: &Arc<Schema>,
176) -> (Vec<(String, Option<String>)>, Arc<Schema>) {
177 let host_yields = host.yield_items();
178 if host_yields.is_empty() {
179 let yield_items: Vec<(String, Option<String>)> = sig
180 .yields
181 .iter()
182 .map(|f| (f.name().clone(), None))
183 .collect();
184 (yield_items, fallback_schema.clone())
185 } else {
186 let output_schema = host
187 .expected_schema()
188 .cloned()
189 .unwrap_or_else(|| fallback_schema.clone());
190 (host_yields.to_vec(), output_schema)
191 }
192}
193
194pub(super) fn run_search_procedure<F, Fut>(
202 proc_name: &'static str,
203 ctx: &ProcedureContext<'_>,
204 args: &[ColumnarValue],
205 sig: &'static ProcedureSignature,
206 run_fn: F,
207) -> Result<SendableRecordBatchStream, FnError>
208where
209 F: FnOnce(QueryProcedureHost, Vec<Value>, Vec<(String, Option<String>)>, SchemaRef) -> Fut
210 + Send
211 + 'static,
212 Fut: Future<Output = DFResult<Option<RecordBatch>>> + Send + 'static,
213{
214 let host = require_host(ctx, proc_name)?.clone();
215 let uni_args = columnar_args_to_values(args);
216 let fallback_schema = Arc::new(Schema::new(sig.yields.clone()));
217 let (yield_items, output_schema) = resolve_yields_and_schema(&host, sig, &fallback_schema);
218
219 let stream_schema = output_schema.clone();
220 let stream = stream::once(async move {
221 let batch = run_fn(host, uni_args, yield_items, output_schema.clone())
222 .await?
223 .unwrap_or_else(|| RecordBatch::new_empty(output_schema.clone()));
224 Ok::<_, datafusion::error::DataFusionError>(batch)
225 });
226 Ok(Box::pin(RecordBatchStreamAdapter::new(
227 stream_schema,
228 stream,
229 )))
230}
231
232pub fn register_into(r: &mut PluginRegistrar<'_>) -> Result<(), PluginError> {
238 r.procedure(
239 QName::new("uni", "vector.query"),
240 signature().clone(),
241 Arc::new(VectorQueryProc),
242 )?;
243 Ok(())
244}