ruvector_server/routes/
points.rs1use crate::{error::Error, state::AppState, Result};
4use axum::{
5 extract::{Path, State},
6 http::StatusCode,
7 response::IntoResponse,
8 routing::{get, post, put},
9 Json, Router,
10};
11use ruvector_core::{SearchQuery, SearchResult, VectorEntry};
12use serde::{Deserialize, Serialize};
13use std::collections::HashMap;
14
15#[derive(Debug, Deserialize)]
17pub struct UpsertPointsRequest {
18 pub points: Vec<VectorEntry>,
20}
21
22#[derive(Debug, Deserialize)]
24pub struct SearchRequest {
25 pub vector: Vec<f32>,
27 #[serde(default = "default_limit")]
29 pub k: usize,
30 pub score_threshold: Option<f32>,
32 pub filter: Option<HashMap<String, serde_json::Value>>,
34}
35
/// Serde default for `SearchRequest::k` when the request omits it.
fn default_limit() -> usize {
    const DEFAULT_K: usize = 10;
    DEFAULT_K
}
39
40#[derive(Debug, Serialize)]
42pub struct SearchResponse {
43 pub results: Vec<SearchResult>,
45}
46
47#[derive(Debug, Serialize)]
49pub struct UpsertResponse {
50 pub ids: Vec<String>,
52}
53
54pub fn routes() -> Router<AppState> {
56 Router::new()
57 .route("/collections/:name/points", put(upsert_points))
58 .route("/collections/:name/points/search", post(search_points))
59 .route("/collections/:name/points/:id", get(get_point))
60}
61
62async fn upsert_points(
66 State(state): State<AppState>,
67 Path(name): Path<String>,
68 Json(req): Json<UpsertPointsRequest>,
69) -> Result<impl IntoResponse> {
70 let db = state
71 .get_collection(&name)
72 .ok_or_else(|| Error::CollectionNotFound(name.clone()))?;
73
74 let ids = db.insert_batch(req.points).map_err(Error::Core)?;
75
76 Ok((StatusCode::OK, Json(UpsertResponse { ids })))
77}
78
79async fn search_points(
83 State(state): State<AppState>,
84 Path(name): Path<String>,
85 Json(req): Json<SearchRequest>,
86) -> Result<impl IntoResponse> {
87 let db = state
88 .get_collection(&name)
89 .ok_or_else(|| Error::CollectionNotFound(name))?;
90
91 let query = SearchQuery {
92 vector: req.vector,
93 k: req.k,
94 filter: req.filter,
95 ef_search: None,
96 };
97
98 let mut results = db.search(query).map_err(Error::Core)?;
99
100 if let Some(threshold) = req.score_threshold {
102 results.retain(|r| r.score >= threshold);
103 }
104
105 Ok(Json(SearchResponse { results }))
106}
107
108async fn get_point(
112 State(state): State<AppState>,
113 Path((name, id)): Path<(String, String)>,
114) -> Result<impl IntoResponse> {
115 let db = state
116 .get_collection(&name)
117 .ok_or_else(|| Error::CollectionNotFound(name))?;
118
119 let entry = db.get(&id).map_err(Error::Core)?;
120
121 Ok(Json(entry))
122}