1use axum::{
2 extract::{Path, State},
3 http::StatusCode,
4 response::Json,
5 routing::{delete, get, post, put},
6 Router,
7};
8use ndarray::Array1;
9use serde::{Deserialize, Serialize};
10use std::collections::HashMap;
11use std::sync::Arc;
12use tokio::sync::RwLock;
13use tower_http::cors::CorsLayer;
14use vectradb_components::{DatabaseStats, SimilarityResult, VectorDatabase, VectraDBError};
15use vectradb_storage::{DatabaseConfig, PersistentVectorDB};
16
17#[derive(Clone)]
19pub struct AppState {
20 pub db: Arc<RwLock<PersistentVectorDB>>,
21}
22
23#[derive(Debug, Deserialize)]
26pub struct CreateVectorRequest {
27 pub id: String,
28 pub vector: Vec<f32>,
29 pub tags: Option<HashMap<String, String>>,
30}
31
32#[derive(Debug, Deserialize)]
33pub struct UpdateVectorRequest {
34 pub vector: Vec<f32>,
35 pub tags: Option<HashMap<String, String>>,
36}
37
38#[derive(Debug, Deserialize)]
39pub struct UpsertVectorRequest {
40 pub vector: Vec<f32>,
41 pub tags: Option<HashMap<String, String>>,
42}
43
44#[derive(Debug, Deserialize)]
45pub struct SearchRequest {
46 pub vector: Vec<f32>,
47 pub top_k: Option<usize>,
48}
49
50#[derive(Debug, Serialize)]
51pub struct VectorResponse {
52 pub id: String,
53 pub vector: Vec<f32>,
54 pub dimension: usize,
55 pub created_at: u64,
56 pub updated_at: u64,
57 pub tags: HashMap<String, String>,
58}
59
60#[derive(Debug, Serialize)]
61pub struct SearchResponse {
62 pub results: Vec<SimilarityResult>,
63 pub total_time_ms: f64,
64}
65
66#[derive(Debug, Serialize)]
67pub struct ErrorResponse {
68 pub error: String,
69 pub message: String,
70}
71
72pub fn create_router(state: AppState) -> Router {
74 Router::new()
75 .route("/health", get(health_check))
76 .route("/stats", get(get_stats))
77 .route("/vectors", post(create_vector))
78 .route("/vectors/:id", get(get_vector))
79 .route("/vectors/:id", put(update_vector))
80 .route("/vectors/:id", delete(delete_vector))
81 .route("/vectors/:id/upsert", put(upsert_vector))
82 .route("/search", post(search_vectors))
83 .route("/vectors", get(list_vectors))
84 .layer(CorsLayer::permissive())
85 .with_state(state)
86}
87
88async fn health_check() -> Result<Json<HashMap<String, String>>, StatusCode> {
90 let mut response = HashMap::new();
91 response.insert("status".to_string(), "healthy".to_string());
92 response.insert("service".to_string(), "vectradb-api".to_string());
93 Ok(Json(response))
94}
95
96async fn get_stats(
98 State(state): State<AppState>,
99) -> Result<Json<DatabaseStats>, (StatusCode, Json<ErrorResponse>)> {
100 let db = state.db.read().await;
101 match db.get_stats() {
102 Ok(stats) => Ok(Json(stats)),
103 Err(e) => Err((
104 StatusCode::INTERNAL_SERVER_ERROR,
105 Json(ErrorResponse {
106 error: "Database error".to_string(),
107 message: e.to_string(),
108 }),
109 )),
110 }
111}
112
113async fn create_vector(
115 State(state): State<AppState>,
116 Json(request): Json<CreateVectorRequest>,
117) -> Result<Json<VectorResponse>, (StatusCode, Json<ErrorResponse>)> {
118 let vector = Array1::from_vec(request.vector);
119
120 let mut db = state.db.write().await;
121 match db.create_vector(request.id.clone(), vector, request.tags) {
122 Ok(_) => {
123 match db.get_vector(&request.id) {
125 Ok(document) => Ok(Json(VectorResponse {
126 id: document.metadata.id,
127 vector: document.data.to_vec(),
128 dimension: document.metadata.dimension,
129 created_at: document.metadata.created_at,
130 updated_at: document.metadata.updated_at,
131 tags: document.metadata.tags,
132 })),
133 Err(e) => Err((
134 StatusCode::INTERNAL_SERVER_ERROR,
135 Json(ErrorResponse {
136 error: "Failed to fetch created vector".to_string(),
137 message: e.to_string(),
138 }),
139 )),
140 }
141 }
142 Err(e) => Err((
143 StatusCode::BAD_REQUEST,
144 Json(ErrorResponse {
145 error: "Failed to create vector".to_string(),
146 message: e.to_string(),
147 }),
148 )),
149 }
150}
151
152async fn get_vector(
154 State(state): State<AppState>,
155 Path(id): Path<String>,
156) -> Result<Json<VectorResponse>, (StatusCode, Json<ErrorResponse>)> {
157 let db = state.db.read().await;
158 match db.get_vector(&id) {
159 Ok(document) => Ok(Json(VectorResponse {
160 id: document.metadata.id,
161 vector: document.data.to_vec(),
162 dimension: document.metadata.dimension,
163 created_at: document.metadata.created_at,
164 updated_at: document.metadata.updated_at,
165 tags: document.metadata.tags,
166 })),
167 Err(VectraDBError::VectorNotFound { .. }) => Err((
168 StatusCode::NOT_FOUND,
169 Json(ErrorResponse {
170 error: "Vector not found".to_string(),
171 message: format!("Vector with ID '{}' not found", id),
172 }),
173 )),
174 Err(e) => Err((
175 StatusCode::INTERNAL_SERVER_ERROR,
176 Json(ErrorResponse {
177 error: "Database error".to_string(),
178 message: e.to_string(),
179 }),
180 )),
181 }
182}
183
184async fn update_vector(
186 State(state): State<AppState>,
187 Path(id): Path<String>,
188 Json(request): Json<UpdateVectorRequest>,
189) -> Result<Json<VectorResponse>, (StatusCode, Json<ErrorResponse>)> {
190 let vector = Array1::from_vec(request.vector);
191
192 let mut db = state.db.write().await;
193 match db.update_vector(&id, vector, request.tags) {
194 Ok(_) => {
195 match db.get_vector(&id) {
197 Ok(document) => Ok(Json(VectorResponse {
198 id: document.metadata.id,
199 vector: document.data.to_vec(),
200 dimension: document.metadata.dimension,
201 created_at: document.metadata.created_at,
202 updated_at: document.metadata.updated_at,
203 tags: document.metadata.tags,
204 })),
205 Err(e) => Err((
206 StatusCode::INTERNAL_SERVER_ERROR,
207 Json(ErrorResponse {
208 error: "Failed to fetch updated vector".to_string(),
209 message: e.to_string(),
210 }),
211 )),
212 }
213 }
214 Err(VectraDBError::VectorNotFound { .. }) => Err((
215 StatusCode::NOT_FOUND,
216 Json(ErrorResponse {
217 error: "Vector not found".to_string(),
218 message: format!("Vector with ID '{}' not found", id),
219 }),
220 )),
221 Err(e) => Err((
222 StatusCode::BAD_REQUEST,
223 Json(ErrorResponse {
224 error: "Failed to update vector".to_string(),
225 message: e.to_string(),
226 }),
227 )),
228 }
229}
230
231async fn delete_vector(
233 State(state): State<AppState>,
234 Path(id): Path<String>,
235) -> Result<StatusCode, (StatusCode, Json<ErrorResponse>)> {
236 let mut db = state.db.write().await;
237 match db.delete_vector(&id) {
238 Ok(_) => Ok(StatusCode::NO_CONTENT),
239 Err(VectraDBError::VectorNotFound { .. }) => Err((
240 StatusCode::NOT_FOUND,
241 Json(ErrorResponse {
242 error: "Vector not found".to_string(),
243 message: format!("Vector with ID '{}' not found", id),
244 }),
245 )),
246 Err(e) => Err((
247 StatusCode::INTERNAL_SERVER_ERROR,
248 Json(ErrorResponse {
249 error: "Database error".to_string(),
250 message: e.to_string(),
251 }),
252 )),
253 }
254}
255
256async fn upsert_vector(
258 State(state): State<AppState>,
259 Path(id): Path<String>,
260 Json(request): Json<UpsertVectorRequest>,
261) -> Result<Json<VectorResponse>, (StatusCode, Json<ErrorResponse>)> {
262 let vector = Array1::from_vec(request.vector);
263
264 let mut db = state.db.write().await;
265 match db.upsert_vector(id.clone(), vector, request.tags) {
266 Ok(_) => {
267 match db.get_vector(&id) {
269 Ok(document) => Ok(Json(VectorResponse {
270 id: document.metadata.id,
271 vector: document.data.to_vec(),
272 dimension: document.metadata.dimension,
273 created_at: document.metadata.created_at,
274 updated_at: document.metadata.updated_at,
275 tags: document.metadata.tags,
276 })),
277 Err(e) => Err((
278 StatusCode::INTERNAL_SERVER_ERROR,
279 Json(ErrorResponse {
280 error: "Failed to fetch upserted vector".to_string(),
281 message: e.to_string(),
282 }),
283 )),
284 }
285 }
286 Err(e) => Err((
287 StatusCode::BAD_REQUEST,
288 Json(ErrorResponse {
289 error: "Failed to upsert vector".to_string(),
290 message: e.to_string(),
291 }),
292 )),
293 }
294}
295
296async fn search_vectors(
298 State(state): State<AppState>,
299 Json(request): Json<SearchRequest>,
300) -> Result<Json<SearchResponse>, (StatusCode, Json<ErrorResponse>)> {
301 let vector = Array1::from_vec(request.vector);
302 let top_k = request.top_k.unwrap_or(10);
303
304 let start_time = std::time::Instant::now();
305 let db = state.db.read().await;
306
307 match db.search_similar(vector, top_k) {
308 Ok(results) => {
309 let total_time = start_time.elapsed().as_secs_f64() * 1000.0; Ok(Json(SearchResponse {
311 results,
312 total_time_ms: total_time,
313 }))
314 }
315 Err(e) => Err((
316 StatusCode::BAD_REQUEST,
317 Json(ErrorResponse {
318 error: "Search failed".to_string(),
319 message: e.to_string(),
320 }),
321 )),
322 }
323}
324
325async fn list_vectors(
327 State(state): State<AppState>,
328) -> Result<Json<Vec<String>>, (StatusCode, Json<ErrorResponse>)> {
329 let db = state.db.read().await;
330 match db.list_vectors() {
331 Ok(ids) => Ok(Json(ids)),
332 Err(e) => Err((
333 StatusCode::INTERNAL_SERVER_ERROR,
334 Json(ErrorResponse {
335 error: "Database error".to_string(),
336 message: e.to_string(),
337 }),
338 )),
339 }
340}
341
342pub async fn start_server(
344 config: DatabaseConfig,
345 port: u16,
346) -> Result<(), Box<dyn std::error::Error>> {
347 let db = PersistentVectorDB::new(config).await?;
349 let state = AppState {
350 db: Arc::new(RwLock::new(db)),
351 };
352
353 let app = create_router(state);
355
356 let listener = tokio::net::TcpListener::bind(format!("0.0.0.0:{}", port)).await?;
358 println!("VectraDB API server running on http://0.0.0.0:{}", port);
359
360 axum::serve(listener, app).await?;
361 Ok(())
362}
363
#[cfg(test)]
mod tests {
    use super::*;

    /// A `CreateVectorRequest` built by hand keeps the ID and vector it was
    /// given.
    #[tokio::test]
    async fn test_create_vector_request() {
        let mut tags = HashMap::new();
        tags.insert("category".to_string(), "test".to_string());

        let request = CreateVectorRequest {
            id: String::from("test_id"),
            vector: vec![1.0, 2.0, 3.0],
            tags: Some(tags),
        };

        assert_eq!(request.id, "test_id");
        assert_eq!(request.vector.len(), 3);
    }

    /// A `SearchRequest` built by hand keeps its vector and `top_k`.
    #[tokio::test]
    async fn test_search_request() {
        let query = vec![1.0, 2.0, 3.0];
        let request = SearchRequest {
            vector: query,
            top_k: Some(5),
        };

        assert_eq!(request.vector.len(), 3);
        assert_eq!(request.top_k, Some(5));
    }
}