velesdb_server/handlers/search/
multi.rs1use axum::{
4 extract::{Path, State},
5 http::StatusCode,
6 response::IntoResponse,
7 Json,
8};
9use std::sync::Arc;
10
11use crate::types::{ErrorResponse, MultiQuerySearchRequest, SearchIdsResponse, SearchResponse};
12use crate::AppState;
13
14use super::pipeline::{
15 finish_search_ids_with_cb, finish_search_with_cb, id_score_results, parse_filter_or_400,
16 validate_query_dimension,
17};
18use super::workers::run_blocking_search;
19use crate::handlers::helpers::{apply_pre_check, extract_client_id, get_vector_collection_or_404};
20
21#[allow(clippy::result_large_err)]
24fn parse_fusion_strategy(
25 req: &MultiQuerySearchRequest,
26 state: &AppState,
27) -> Result<velesdb_core::FusionStrategy, axum::response::Response> {
28 use velesdb_core::FusionStrategy;
29 match req.strategy.to_lowercase().as_str() {
30 "average" | "avg" => Ok(FusionStrategy::Average),
31 "maximum" | "max" => Ok(FusionStrategy::Maximum),
32 "rrf" => Ok(FusionStrategy::RRF { k: req.rrf_k }),
33 "weighted" => Ok(FusionStrategy::Weighted {
34 avg_weight: req.avg_weight,
35 max_weight: req.max_weight,
36 hit_weight: req.hit_weight,
37 }),
38 "relative_score" | "rsf" => Ok(FusionStrategy::RelativeScore {
39 dense_weight: req.dense_weight,
40 sparse_weight: req.sparse_weight,
41 }),
42 _ => {
43 state.operational_metrics.inc_errors();
44 Err((
45 StatusCode::BAD_REQUEST,
46 Json(ErrorResponse {
47 error: format!(
48 "Invalid strategy: {}. Valid: average, maximum, rrf, weighted, \
49 relative_score",
50 req.strategy
51 ),
52 code: None,
53 }),
54 )
55 .into_response())
56 }
57 }
58}
59
60#[allow(clippy::result_large_err)]
63fn validate_query_vectors(
64 state: &AppState,
65 name: &str,
66 expected_dimension: usize,
67 vectors: &[Vec<f32>],
68) -> Result<(), axum::response::Response> {
69 for (idx, vector) in vectors.iter().enumerate() {
70 if let Err(error) = validate_query_dimension(state, name, expected_dimension, vector) {
71 state.operational_metrics.inc_errors();
72 return Err((
73 StatusCode::BAD_REQUEST,
74 Json(ErrorResponse {
75 error: format!("Invalid query vector at index {idx}: {}", error.error),
76 code: error.code.clone(),
77 }),
78 )
79 .into_response());
80 }
81 }
82 Ok(())
83}
84
85#[allow(clippy::result_large_err)]
90fn prepare_multi_query(
91 state: &AppState,
92 headers: &axum::http::HeaderMap,
93 name: &str,
94 req: &MultiQuerySearchRequest,
95) -> Result<
96 (
97 velesdb_core::collection::VectorCollection,
98 velesdb_core::FusionStrategy,
99 ),
100 axum::response::Response,
101> {
102 state.onboarding_metrics.record_search_request();
103
104 let collection = get_vector_collection_or_404(state, name)?;
105
106 state.operational_metrics.record_vector_query();
109
110 let client_id = extract_client_id(headers);
111 if let Err(resp) = apply_pre_check(collection.guard_rails(), &client_id) {
112 state.operational_metrics.inc_rate_limited();
113 return Err(resp);
114 }
115
116 let strategy = parse_fusion_strategy(req, state)?;
117
118 let expected_dimension = collection.config().dimension;
119 validate_query_vectors(state, name, expected_dimension, &req.vectors)?;
120
121 Ok((collection, strategy))
122}
123
124#[utoipa::path(
126 post,
127 path = "/collections/{name}/search/multi",
128 tag = "search",
129 params(("name" = String, Path, description = "Collection name")),
130 request_body = MultiQuerySearchRequest,
131 responses(
132 (status = 200, description = "Multi-query search results", body = SearchResponse),
133 (status = 404, description = "Collection not found", body = ErrorResponse),
134 (status = 500, description = "Internal server error", body = ErrorResponse)
135 )
136)]
137#[allow(clippy::result_large_err)]
138pub async fn multi_query_search(
139 State(state): State<Arc<AppState>>,
140 headers: axum::http::HeaderMap,
141 Path(name): Path<String>,
142 Json(req): Json<MultiQuerySearchRequest>,
143) -> impl IntoResponse {
144 let (collection, strategy) = match prepare_multi_query(&state, &headers, &name, &req) {
145 Ok(v) => v,
146 Err(resp) => return resp,
147 };
148
149 let filter = match req.filter.as_ref() {
157 Some(filter_json) => match parse_filter_or_400(filter_json, &state.onboarding_metrics) {
158 Ok(f) => Some(f),
159 Err(resp) => {
160 state.operational_metrics.inc_errors();
161 return resp;
162 }
163 },
164 None => None,
165 };
166
167 let start = std::time::Instant::now();
168
169 let collection_for_work = collection.clone();
177 let vectors = req.vectors;
178 let top_k = req.top_k;
179
180 let work_result = run_blocking_search(move || {
181 let query_refs: Vec<&[f32]> = vectors.iter().map(Vec::as_slice).collect();
182 Ok(collection_for_work.multi_query_search(&query_refs, top_k, strategy, filter.as_ref()))
183 })
184 .await;
185
186 let search_result = match work_result {
187 Ok(inner) => inner,
188 Err(resp) => {
189 state.operational_metrics.inc_errors();
190 return resp;
191 }
192 };
193
194 finish_search_with_cb(&state, &name, start, &collection, search_result)
195}
196
197#[utoipa::path(
203 post,
204 path = "/collections/{name}/search/multi/ids",
205 tag = "search",
206 params(("name" = String, Path, description = "Collection name")),
207 request_body = MultiQuerySearchRequest,
208 responses(
209 (status = 200, description = "Multi-query ids-only results", body = SearchIdsResponse),
210 (status = 400, description = "Invalid request", body = ErrorResponse),
211 (status = 404, description = "Collection not found", body = ErrorResponse)
212 )
213)]
214#[allow(clippy::result_large_err)]
215pub async fn multi_query_search_ids(
216 State(state): State<Arc<AppState>>,
217 headers: axum::http::HeaderMap,
218 Path(name): Path<String>,
219 Json(req): Json<MultiQuerySearchRequest>,
220) -> impl IntoResponse {
221 let (collection, strategy) = match prepare_multi_query(&state, &headers, &name, &req) {
222 Ok(v) => v,
223 Err(resp) => return resp,
224 };
225
226 if req.filter.is_some() {
229 state.operational_metrics.inc_errors();
230 return (
231 StatusCode::BAD_REQUEST,
232 Json(ErrorResponse {
233 error: "Metadata filters are not supported on /search/multi/ids; \
234 use /search/multi for filtered multi-query search."
235 .to_string(),
236 code: None,
237 }),
238 )
239 .into_response();
240 }
241
242 let start = std::time::Instant::now();
243 let collection_for_work = collection.clone();
244 let vectors = req.vectors;
245 let top_k = req.top_k;
246
247 let work_result = run_blocking_search(move || {
248 let query_refs: Vec<&[f32]> = vectors.iter().map(Vec::as_slice).collect();
249 Ok(collection_for_work
250 .multi_query_search_ids(&query_refs, top_k, strategy)
251 .map(id_score_results))
252 })
253 .await;
254
255 let search_result = match work_result {
256 Ok(inner) => inner,
257 Err(resp) => {
258 state.operational_metrics.inc_errors();
259 return resp;
260 }
261 };
262
263 finish_search_ids_with_cb(&state, &name, start, &collection, search_result)
264}