Skip to main content

velesdb_server/handlers/search/
multi.rs

1//! Multi-query search handler: fuse results from multiple query vectors.
2
3use axum::{
4    extract::{Path, State},
5    http::StatusCode,
6    response::IntoResponse,
7    Json,
8};
9use std::sync::Arc;
10
11use crate::types::{ErrorResponse, MultiQuerySearchRequest, SearchIdsResponse, SearchResponse};
12use crate::AppState;
13
14use super::pipeline::{
15    finish_search_ids_with_cb, finish_search_with_cb, id_score_results, parse_filter_or_400,
16    validate_query_dimension,
17};
18use super::workers::run_blocking_search;
19use crate::handlers::helpers::{apply_pre_check, extract_client_id, get_vector_collection_or_404};
20
21/// Parse the fusion strategy name into a `FusionStrategy`, returning a 400
22/// response (and bumping the error counter) for an unknown strategy.
23#[allow(clippy::result_large_err)]
24fn parse_fusion_strategy(
25    req: &MultiQuerySearchRequest,
26    state: &AppState,
27) -> Result<velesdb_core::FusionStrategy, axum::response::Response> {
28    use velesdb_core::FusionStrategy;
29    match req.strategy.to_lowercase().as_str() {
30        "average" | "avg" => Ok(FusionStrategy::Average),
31        "maximum" | "max" => Ok(FusionStrategy::Maximum),
32        "rrf" => Ok(FusionStrategy::RRF { k: req.rrf_k }),
33        "weighted" => Ok(FusionStrategy::Weighted {
34            avg_weight: req.avg_weight,
35            max_weight: req.max_weight,
36            hit_weight: req.hit_weight,
37        }),
38        "relative_score" | "rsf" => Ok(FusionStrategy::RelativeScore {
39            dense_weight: req.dense_weight,
40            sparse_weight: req.sparse_weight,
41        }),
42        _ => {
43            state.operational_metrics.inc_errors();
44            Err((
45                StatusCode::BAD_REQUEST,
46                Json(ErrorResponse {
47                    error: format!(
48                        "Invalid strategy: {}. Valid: average, maximum, rrf, weighted, \
49                         relative_score",
50                        req.strategy
51                    ),
52                    code: None,
53                }),
54            )
55                .into_response())
56        }
57    }
58}
59
60/// Validate every query vector's dimension, returning a 400 on the first
61/// mismatch (with the offending index in the message).
62#[allow(clippy::result_large_err)]
63fn validate_query_vectors(
64    state: &AppState,
65    name: &str,
66    expected_dimension: usize,
67    vectors: &[Vec<f32>],
68) -> Result<(), axum::response::Response> {
69    for (idx, vector) in vectors.iter().enumerate() {
70        if let Err(error) = validate_query_dimension(state, name, expected_dimension, vector) {
71            state.operational_metrics.inc_errors();
72            return Err((
73                StatusCode::BAD_REQUEST,
74                Json(ErrorResponse {
75                    error: format!("Invalid query vector at index {idx}: {}", error.error),
76                    code: error.code.clone(),
77                }),
78            )
79                .into_response());
80        }
81    }
82    Ok(())
83}
84
85/// Shared preamble for the multi-query handlers: records metrics, resolves the
86/// collection, enforces guard rails, parses the fusion strategy, and validates
87/// every query vector's dimension. Returns the collection and strategy on
88/// success, or the appropriate error response.
89#[allow(clippy::result_large_err)]
90fn prepare_multi_query(
91    state: &AppState,
92    headers: &axum::http::HeaderMap,
93    name: &str,
94    req: &MultiQuerySearchRequest,
95) -> Result<
96    (
97        velesdb_core::collection::VectorCollection,
98        velesdb_core::FusionStrategy,
99    ),
100    axum::response::Response,
101> {
102    state.onboarding_metrics.record_search_request();
103
104    let collection = get_vector_collection_or_404(state, name)?;
105
106    // Record query type only after confirming the collection exists, so
107    // 404s do not inflate queries_total or vector_queries.
108    state.operational_metrics.record_vector_query();
109
110    let client_id = extract_client_id(headers);
111    if let Err(resp) = apply_pre_check(collection.guard_rails(), &client_id) {
112        state.operational_metrics.inc_rate_limited();
113        return Err(resp);
114    }
115
116    let strategy = parse_fusion_strategy(req, state)?;
117
118    let expected_dimension = collection.config().dimension;
119    validate_query_vectors(state, name, expected_dimension, &req.vectors)?;
120
121    Ok((collection, strategy))
122}
123
124/// Multi-query search with fusion strategies.
125#[utoipa::path(
126    post,
127    path = "/collections/{name}/search/multi",
128    tag = "search",
129    params(("name" = String, Path, description = "Collection name")),
130    request_body = MultiQuerySearchRequest,
131    responses(
132        (status = 200, description = "Multi-query search results", body = SearchResponse),
133        (status = 404, description = "Collection not found", body = ErrorResponse),
134        (status = 500, description = "Internal server error", body = ErrorResponse)
135    )
136)]
137#[allow(clippy::result_large_err)]
138pub async fn multi_query_search(
139    State(state): State<Arc<AppState>>,
140    headers: axum::http::HeaderMap,
141    Path(name): Path<String>,
142    Json(req): Json<MultiQuerySearchRequest>,
143) -> impl IntoResponse {
144    let (collection, strategy) = match prepare_multi_query(&state, &headers, &name, &req) {
145        Ok(v) => v,
146        Err(resp) => return resp,
147    };
148
149    // Parse the optional metadata filter. We need to materialise the
150    // `Filter` before starting the stopwatch so that a malformed filter
151    // yields a 400 response instead of a misleading 200 with unfiltered
152    // results. Regression guard: see
153    // `test_multi_query_search_with_filter_excludes_nonmatching_points`
154    // and `test_multi_query_search_with_invalid_filter_returns_400`
155    // (F-04).
156    let filter = match req.filter.as_ref() {
157        Some(filter_json) => match parse_filter_or_400(filter_json, &state.onboarding_metrics) {
158            Ok(f) => Some(f),
159            Err(resp) => {
160                state.operational_metrics.inc_errors();
161                return resp;
162            }
163        },
164        None => None,
165    };
166
167    let start = std::time::Instant::now();
168
169    // F-01 sweep: multi-vector fusion is CPU-bound (multiple HNSW
170    // passes plus a fusion step) and was previously executed on the
171    // async runtime thread. Move it to a blocking worker so concurrent
172    // requests stay responsive. We move `vectors` (owned
173    // `Vec<Vec<f32>>`) into the closure and rebuild the `&[f32]` slice
174    // view inside, because `spawn_blocking` requires a 'static closure
175    // and borrowed slice references cannot cross the boundary.
176    let collection_for_work = collection.clone();
177    let vectors = req.vectors;
178    let top_k = req.top_k;
179
180    let work_result = run_blocking_search(move || {
181        let query_refs: Vec<&[f32]> = vectors.iter().map(Vec::as_slice).collect();
182        Ok(collection_for_work.multi_query_search(&query_refs, top_k, strategy, filter.as_ref()))
183    })
184    .await;
185
186    let search_result = match work_result {
187        Ok(inner) => inner,
188        Err(resp) => {
189            state.operational_metrics.inc_errors();
190            return resp;
191        }
192    };
193
194    finish_search_with_cb(&state, &name, start, &collection, search_result)
195}
196
197/// Multi-query fusion search returning only ids and scores (no payloads).
198///
199/// Faster than `/search/multi` when payloads are not needed: the core
200/// `multi_query_search_ids` kernel skips payload hydration. Metadata filters
201/// are not supported here — use `/search/multi` for filtered fusion.
202#[utoipa::path(
203    post,
204    path = "/collections/{name}/search/multi/ids",
205    tag = "search",
206    params(("name" = String, Path, description = "Collection name")),
207    request_body = MultiQuerySearchRequest,
208    responses(
209        (status = 200, description = "Multi-query ids-only results", body = SearchIdsResponse),
210        (status = 400, description = "Invalid request", body = ErrorResponse),
211        (status = 404, description = "Collection not found", body = ErrorResponse)
212    )
213)]
214#[allow(clippy::result_large_err)]
215pub async fn multi_query_search_ids(
216    State(state): State<Arc<AppState>>,
217    headers: axum::http::HeaderMap,
218    Path(name): Path<String>,
219    Json(req): Json<MultiQuerySearchRequest>,
220) -> impl IntoResponse {
221    let (collection, strategy) = match prepare_multi_query(&state, &headers, &name, &req) {
222        Ok(v) => v,
223        Err(resp) => return resp,
224    };
225
226    // The ids-only fusion kernel has no filter parameter. Reject filters
227    // explicitly rather than silently returning unfiltered results.
228    if req.filter.is_some() {
229        state.operational_metrics.inc_errors();
230        return (
231            StatusCode::BAD_REQUEST,
232            Json(ErrorResponse {
233                error: "Metadata filters are not supported on /search/multi/ids; \
234                        use /search/multi for filtered multi-query search."
235                    .to_string(),
236                code: None,
237            }),
238        )
239            .into_response();
240    }
241
242    let start = std::time::Instant::now();
243    let collection_for_work = collection.clone();
244    let vectors = req.vectors;
245    let top_k = req.top_k;
246
247    let work_result = run_blocking_search(move || {
248        let query_refs: Vec<&[f32]> = vectors.iter().map(Vec::as_slice).collect();
249        Ok(collection_for_work
250            .multi_query_search_ids(&query_refs, top_k, strategy)
251            .map(id_score_results))
252    })
253    .await;
254
255    let search_result = match work_result {
256        Ok(inner) => inner,
257        Err(resp) => {
258            state.operational_metrics.inc_errors();
259            return resp;
260        }
261    };
262
263    finish_search_ids_with_cb(&state, &name, start, &collection, search_result)
264}