Skip to main content

uni_query/procedures_plugin/
sparse.rs

1// SPDX-License-Identifier: Apache-2.0
2// Copyright 2024-2026 Dragonscale Team
3
4//! `uni.sparse.query` — scored sparse-vector (SPLADE / learned-sparse) search.
5
6use std::sync::Arc;
7use std::sync::OnceLock;
8
9use arrow_schema::DataType;
10use datafusion::execution::SendableRecordBatchStream;
11use datafusion::logical_expr::ColumnarValue;
12use uni_plugin::traits::procedure::{
13    NamedArgType, ProcedureContext, ProcedureMode, ProcedurePlugin, ProcedureSignature,
14};
15use uni_plugin::traits::scalar::ArgType;
16use uni_plugin::{FnError, PluginError, PluginRegistrar, QName, SideEffects};
17
18use crate::procedures_plugin::vector::{fts_query_yields, run_search_procedure};
19use crate::query::df_graph::search_procedures::run_sparse_query;
20
21// Rust guideline compliant
22
23fn signature() -> &'static ProcedureSignature {
24    static SIG: OnceLock<ProcedureSignature> = OnceLock::new();
25    SIG.get_or_init(|| ProcedureSignature {
26        args: vec![
27            NamedArgType {
28                name: smol_str::SmolStr::new("label"),
29                ty: ArgType::Primitive(DataType::Utf8),
30                default: None,
31                doc: "Vertex label to search.".to_owned(),
32            },
33            NamedArgType {
34                name: smol_str::SmolStr::new("property"),
35                ty: ArgType::Primitive(DataType::Utf8),
36                default: None,
37                doc: "Sparse-vector property name on the label.".to_owned(),
38            },
39            NamedArgType {
40                name: smol_str::SmolStr::new("query"),
41                ty: ArgType::CypherValue,
42                default: None,
43                doc: "Query sparse vector ({indices, values}).".to_owned(),
44            },
45            NamedArgType {
46                name: smol_str::SmolStr::new("k"),
47                ty: ArgType::Primitive(DataType::Int64),
48                default: None,
49                doc: "Number of top hits to return.".to_owned(),
50            },
51            NamedArgType {
52                name: smol_str::SmolStr::new("filter"),
53                ty: ArgType::Primitive(DataType::Utf8),
54                default: None,
55                doc: "Optional pushdown filter expression.".to_owned(),
56            },
57            NamedArgType {
58                name: smol_str::SmolStr::new("threshold"),
59                ty: ArgType::Primitive(DataType::Float64),
60                default: None,
61                doc: "Optional minimum dot-score threshold (post-filter).".to_owned(),
62            },
63            NamedArgType {
64                name: smol_str::SmolStr::new("options"),
65                ty: ArgType::CypherValue,
66                default: None,
67                doc: "Optional extra options map (e.g. over_fetch).".to_owned(),
68            },
69        ],
70        // Sparse scoring is a dot product (similarity), so like FTS there is no
71        // `distance` column — only `score`/`rerank_score`.
72        yields: fts_query_yields(),
73        mode: ProcedureMode::Read,
74        side_effects: SideEffects::ReadOnly,
75        retry_contract: None,
76        batch_input: None,
77        docs: "Scored sparse-vector (SPLADE / learned-sparse) retrieval by dot product, \
78               MVCC/L0-aware via exact re-scoring."
79            .to_owned(),
80    })
81}
82
83#[derive(Debug)]
84struct SparseQueryProc;
85
86impl ProcedurePlugin for SparseQueryProc {
87    fn signature(&self) -> &ProcedureSignature {
88        signature()
89    }
90
91    fn invoke(
92        &self,
93        ctx: ProcedureContext<'_>,
94        args: &[ColumnarValue],
95    ) -> Result<SendableRecordBatchStream, FnError> {
96        run_search_procedure(
97            "uni.sparse.query",
98            &ctx,
99            args,
100            signature(),
101            |host, uni_args, yield_items, output_schema| async move {
102                let target_properties = host.target_properties().clone();
103                run_sparse_query(
104                    &host,
105                    &uni_args,
106                    &yield_items,
107                    &target_properties,
108                    &output_schema,
109                )
110                .await
111            },
112        )
113    }
114}
115
116/// Register `uni.sparse.query` into `r`.
117///
118/// # Errors
119///
120/// Returns [`PluginError::DuplicateRegistration`] if a qname is taken.
121pub fn register_into(r: &mut PluginRegistrar<'_>) -> Result<(), PluginError> {
122    r.procedure(
123        QName::new("uni", "sparse.query"),
124        signature().clone(),
125        Arc::new(SparseQueryProc),
126    )?;
127    Ok(())
128}