1use std::sync::Arc;
11
12use axum::extract::State;
13use axum::routing::{get, post};
14use axum::{Json, Router};
15use serde::{Deserialize, Serialize};
16use tokio::sync::Mutex;
17
18use rvf_runtime::{QueryOptions, RvfStore};
19
20use crate::error::ServerError;
21
22pub type SharedStore = Arc<Mutex<RvfStore>>;
24
25pub fn router(store: SharedStore) -> Router {
27 Router::new()
28 .route("/v1/ingest", post(ingest))
29 .route("/v1/query", post(query))
30 .route("/v1/delete", post(delete))
31 .route("/v1/status", get(status))
32 .route("/v1/health", get(health))
33 .with_state(store)
34}
35
36#[derive(Deserialize)]
39pub struct IngestRequest {
40 pub vectors: Vec<Vec<f32>>,
42 pub ids: Vec<u64>,
44 pub metadata: Option<Vec<MetadataEntryJson>>,
46}
47
48#[derive(Deserialize)]
49pub struct MetadataEntryJson {
50 pub field_id: u16,
51 pub value: MetadataValueJson,
52}
53
54#[derive(Deserialize)]
55#[serde(untagged)]
56pub enum MetadataValueJson {
57 U64(u64),
58 F64(f64),
59 String(String),
60}
61
62#[derive(Serialize, Deserialize)]
63pub struct IngestResponse {
64 pub accepted: u64,
65 pub rejected: u64,
66 pub epoch: u32,
67}
68
69#[derive(Deserialize)]
70pub struct QueryRequest {
71 pub vector: Vec<f32>,
73 pub k: usize,
75 pub ef_search: Option<u16>,
77}
78
79#[derive(Serialize, Deserialize)]
80pub struct QueryResponse {
81 pub results: Vec<QueryResultEntry>,
82}
83
84#[derive(Serialize, Deserialize)]
85pub struct QueryResultEntry {
86 pub id: u64,
87 pub distance: f32,
88}
89
90#[derive(Deserialize)]
91pub struct DeleteRequest {
92 pub ids: Vec<u64>,
94}
95
96#[derive(Serialize, Deserialize)]
97pub struct DeleteResponse {
98 pub deleted: u64,
99 pub epoch: u32,
100}
101
102#[derive(Serialize, Deserialize)]
103pub struct StatusResponse {
104 pub total_vectors: u64,
105 pub total_segments: u32,
106 pub file_size: u64,
107 pub current_epoch: u32,
108 pub profile_id: u8,
109 pub dead_space_ratio: f64,
110 pub read_only: bool,
111}
112
113#[derive(Serialize)]
114pub struct HealthResponse {
115 pub status: &'static str,
116}
117
118async fn ingest(
121 State(store): State<SharedStore>,
122 Json(req): Json<IngestRequest>,
123) -> Result<Json<IngestResponse>, ServerError> {
124 if req.vectors.len() != req.ids.len() {
125 return Err(ServerError::BadRequest(
126 "vectors and ids must have the same length".into(),
127 ));
128 }
129
130 let vec_refs: Vec<&[f32]> = req.vectors.iter().map(|v| v.as_slice()).collect();
131
132 let metadata: Option<Vec<rvf_runtime::MetadataEntry>> = req.metadata.map(|entries| {
133 entries
134 .into_iter()
135 .map(|e| rvf_runtime::MetadataEntry {
136 field_id: e.field_id,
137 value: match e.value {
138 MetadataValueJson::U64(v) => rvf_runtime::MetadataValue::U64(v),
139 MetadataValueJson::F64(v) => rvf_runtime::MetadataValue::F64(v),
140 MetadataValueJson::String(v) => rvf_runtime::MetadataValue::String(v),
141 },
142 })
143 .collect()
144 });
145
146 let result = {
147 let mut s = store.lock().await;
148 s.ingest_batch(
149 &vec_refs,
150 &req.ids,
151 metadata.as_deref(),
152 )?
153 };
154
155 Ok(Json(IngestResponse {
156 accepted: result.accepted,
157 rejected: result.rejected,
158 epoch: result.epoch,
159 }))
160}
161
162async fn query(
163 State(store): State<SharedStore>,
164 Json(req): Json<QueryRequest>,
165) -> Result<Json<QueryResponse>, ServerError> {
166 if req.k == 0 {
167 return Err(ServerError::BadRequest("k must be > 0".into()));
168 }
169
170 let opts = QueryOptions {
171 ef_search: req.ef_search.unwrap_or(100),
172 ..Default::default()
173 };
174
175 let results = {
176 let s = store.lock().await;
177 s.query(&req.vector, req.k, &opts)?
178 };
179
180 Ok(Json(QueryResponse {
181 results: results
182 .into_iter()
183 .map(|r| QueryResultEntry {
184 id: r.id,
185 distance: r.distance,
186 })
187 .collect(),
188 }))
189}
190
191async fn delete(
192 State(store): State<SharedStore>,
193 Json(req): Json<DeleteRequest>,
194) -> Result<Json<DeleteResponse>, ServerError> {
195 if req.ids.is_empty() {
196 return Err(ServerError::BadRequest("ids must not be empty".into()));
197 }
198
199 let result = {
200 let mut s = store.lock().await;
201 s.delete(&req.ids)?
202 };
203
204 Ok(Json(DeleteResponse {
205 deleted: result.deleted,
206 epoch: result.epoch,
207 }))
208}
209
210async fn status(
211 State(store): State<SharedStore>,
212) -> Result<Json<StatusResponse>, ServerError> {
213 let s = store.lock().await;
214 let st = s.status();
215
216 Ok(Json(StatusResponse {
217 total_vectors: st.total_vectors,
218 total_segments: st.total_segments,
219 file_size: st.file_size,
220 current_epoch: st.current_epoch,
221 profile_id: st.profile_id,
222 dead_space_ratio: st.dead_space_ratio,
223 read_only: st.read_only,
224 }))
225}
226
227async fn health() -> Json<HealthResponse> {
228 Json(HealthResponse { status: "ok" })
229}
230
#[cfg(test)]
mod tests {
    use super::*;
    use axum::body::Body;
    use axum::http::{Request, StatusCode};
    use rvf_runtime::RvfOptions;
    use tempfile::TempDir;
    use tower::ServiceExt;

    /// Creates an empty 4-dimensional store inside a fresh temporary directory.
    ///
    /// The `TempDir` guard is returned alongside the store. Callers keep it
    /// bound for the whole test so the directory is not removed early.
    fn create_test_store() -> (TempDir, SharedStore) {
        let dir = TempDir::new().unwrap();
        let options = RvfOptions {
            dimension: 4,
            ..Default::default()
        };
        let store = RvfStore::create(&dir.path().join("test.rvf"), options).unwrap();
        (dir, Arc::new(Mutex::new(store)))
    }

    /// Builds a `POST` request to `uri` carrying `payload` as a JSON body.
    fn post_json(uri: &str, payload: &serde_json::Value) -> Request<Body> {
        Request::builder()
            .method("POST")
            .uri(uri)
            .header("content-type", "application/json")
            .body(Body::from(serde_json::to_vec(payload).unwrap()))
            .unwrap()
    }

    /// Builds a body-less request to `uri` using the builder's default method (GET).
    fn get_request(uri: &str) -> Request<Body> {
        Request::builder().uri(uri).body(Body::empty()).unwrap()
    }

    /// Dispatches `request` to a newly built router over `store` and returns
    /// the response status together with the fully buffered body.
    async fn send(store: &SharedStore, request: Request<Body>) -> (StatusCode, Vec<u8>) {
        let response = router(Arc::clone(store)).oneshot(request).await.unwrap();
        let code = response.status();
        let bytes = axum::body::to_bytes(response.into_body(), usize::MAX)
            .await
            .unwrap();
        (code, bytes.to_vec())
    }

    #[tokio::test]
    async fn test_health() {
        let (_dir, store) = create_test_store();

        let (code, _) = send(&store, get_request("/v1/health")).await;
        assert_eq!(code, StatusCode::OK);
    }

    #[tokio::test]
    async fn test_status_empty_store() {
        let (_dir, store) = create_test_store();

        let (code, body) = send(&store, get_request("/v1/status")).await;
        assert_eq!(code, StatusCode::OK);

        let status: StatusResponse = serde_json::from_slice(&body).unwrap();
        assert_eq!(status.total_vectors, 0);
        assert!(!status.read_only);
    }

    #[tokio::test]
    async fn test_ingest_and_query() {
        let (_dir, store) = create_test_store();

        let ingest_payload = serde_json::json!({
            "vectors": [[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0]],
            "ids": [1, 2]
        });
        let (code, body) = send(&store, post_json("/v1/ingest", &ingest_payload)).await;
        assert_eq!(code, StatusCode::OK);
        let ingested: IngestResponse = serde_json::from_slice(&body).unwrap();
        assert_eq!(ingested.accepted, 2);
        assert_eq!(ingested.rejected, 0);

        let query_payload = serde_json::json!({
            "vector": [1.0, 0.0, 0.0, 0.0],
            "k": 2
        });
        let (code, body) = send(&store, post_json("/v1/query", &query_payload)).await;
        assert_eq!(code, StatusCode::OK);
        let found: QueryResponse = serde_json::from_slice(&body).unwrap();
        assert_eq!(found.results.len(), 2);
        assert_eq!(found.results[0].id, 1);
        assert!(found.results[0].distance < f32::EPSILON);
    }

    #[tokio::test]
    async fn test_ingest_and_delete() {
        let (_dir, store) = create_test_store();

        let ingest_payload = serde_json::json!({
            "vectors": [
                [1.0, 0.0, 0.0, 0.0],
                [0.0, 1.0, 0.0, 0.0],
                [0.0, 0.0, 1.0, 0.0]
            ],
            "ids": [10, 20, 30]
        });
        let (code, _) = send(&store, post_json("/v1/ingest", &ingest_payload)).await;
        assert_eq!(code, StatusCode::OK);

        let delete_payload = serde_json::json!({ "ids": [20] });
        let (code, body) = send(&store, post_json("/v1/delete", &delete_payload)).await;
        assert_eq!(code, StatusCode::OK);
        let removed: DeleteResponse = serde_json::from_slice(&body).unwrap();
        assert_eq!(removed.deleted, 1);

        let (_, body) = send(&store, get_request("/v1/status")).await;
        let status: StatusResponse = serde_json::from_slice(&body).unwrap();
        assert_eq!(status.total_vectors, 2);
    }

    #[tokio::test]
    async fn test_ingest_bad_request() {
        let (_dir, store) = create_test_store();

        // One vector but two ids: the length check must reject this.
        let payload = serde_json::json!({
            "vectors": [[1.0, 0.0, 0.0, 0.0]],
            "ids": [1, 2]
        });
        let (code, _) = send(&store, post_json("/v1/ingest", &payload)).await;
        assert_eq!(code, StatusCode::BAD_REQUEST);
    }

    #[tokio::test]
    async fn test_query_bad_k() {
        let (_dir, store) = create_test_store();

        let payload = serde_json::json!({
            "vector": [1.0, 0.0, 0.0, 0.0],
            "k": 0
        });
        let (code, _) = send(&store, post_json("/v1/query", &payload)).await;
        assert_eq!(code, StatusCode::BAD_REQUEST);
    }

    #[tokio::test]
    async fn test_delete_empty_ids() {
        let (_dir, store) = create_test_store();

        let payload = serde_json::json!({ "ids": [] });
        let (code, _) = send(&store, post_json("/v1/delete", &payload)).await;
        assert_eq!(code, StatusCode::BAD_REQUEST);
    }
}