uni_query/procedures_plugin/
sparse.rs1use std::sync::Arc;
7use std::sync::OnceLock;
8
9use arrow_schema::DataType;
10use datafusion::execution::SendableRecordBatchStream;
11use datafusion::logical_expr::ColumnarValue;
12use uni_plugin::traits::procedure::{
13 NamedArgType, ProcedureContext, ProcedureMode, ProcedurePlugin, ProcedureSignature,
14};
15use uni_plugin::traits::scalar::ArgType;
16use uni_plugin::{FnError, PluginError, PluginRegistrar, QName, SideEffects};
17
18use crate::procedures_plugin::vector::{fts_query_yields, run_search_procedure};
19use crate::query::df_graph::search_procedures::run_sparse_query;
20
21fn signature() -> &'static ProcedureSignature {
24 static SIG: OnceLock<ProcedureSignature> = OnceLock::new();
25 SIG.get_or_init(|| ProcedureSignature {
26 args: vec![
27 NamedArgType {
28 name: smol_str::SmolStr::new("label"),
29 ty: ArgType::Primitive(DataType::Utf8),
30 default: None,
31 doc: "Vertex label to search.".to_owned(),
32 },
33 NamedArgType {
34 name: smol_str::SmolStr::new("property"),
35 ty: ArgType::Primitive(DataType::Utf8),
36 default: None,
37 doc: "Sparse-vector property name on the label.".to_owned(),
38 },
39 NamedArgType {
40 name: smol_str::SmolStr::new("query"),
41 ty: ArgType::CypherValue,
42 default: None,
43 doc: "Query sparse vector ({indices, values}).".to_owned(),
44 },
45 NamedArgType {
46 name: smol_str::SmolStr::new("k"),
47 ty: ArgType::Primitive(DataType::Int64),
48 default: None,
49 doc: "Number of top hits to return.".to_owned(),
50 },
51 NamedArgType {
52 name: smol_str::SmolStr::new("filter"),
53 ty: ArgType::Primitive(DataType::Utf8),
54 default: None,
55 doc: "Optional pushdown filter expression.".to_owned(),
56 },
57 NamedArgType {
58 name: smol_str::SmolStr::new("threshold"),
59 ty: ArgType::Primitive(DataType::Float64),
60 default: None,
61 doc: "Optional minimum dot-score threshold (post-filter).".to_owned(),
62 },
63 NamedArgType {
64 name: smol_str::SmolStr::new("options"),
65 ty: ArgType::CypherValue,
66 default: None,
67 doc: "Optional extra options map (e.g. over_fetch).".to_owned(),
68 },
69 ],
70 yields: fts_query_yields(),
73 mode: ProcedureMode::Read,
74 side_effects: SideEffects::ReadOnly,
75 retry_contract: None,
76 batch_input: None,
77 docs: "Scored sparse-vector (SPLADE / learned-sparse) retrieval by dot product, \
78 MVCC/L0-aware via exact re-scoring."
79 .to_owned(),
80 })
81}
82
83#[derive(Debug)]
84struct SparseQueryProc;
85
86impl ProcedurePlugin for SparseQueryProc {
87 fn signature(&self) -> &ProcedureSignature {
88 signature()
89 }
90
91 fn invoke(
92 &self,
93 ctx: ProcedureContext<'_>,
94 args: &[ColumnarValue],
95 ) -> Result<SendableRecordBatchStream, FnError> {
96 run_search_procedure(
97 "uni.sparse.query",
98 &ctx,
99 args,
100 signature(),
101 |host, uni_args, yield_items, output_schema| async move {
102 let target_properties = host.target_properties().clone();
103 run_sparse_query(
104 &host,
105 &uni_args,
106 &yield_items,
107 &target_properties,
108 &output_schema,
109 )
110 .await
111 },
112 )
113 }
114}
115
116pub fn register_into(r: &mut PluginRegistrar<'_>) -> Result<(), PluginError> {
122 r.procedure(
123 QName::new("uni", "sparse.query"),
124 signature().clone(),
125 Arc::new(SparseQueryProc),
126 )?;
127 Ok(())
128}