vectradb_api/
lib.rs

1use axum::{
2    extract::{Path, State},
3    http::StatusCode,
4    response::Json,
5    routing::{delete, get, post, put},
6    Router,
7};
8use ndarray::Array1;
9use serde::{Deserialize, Serialize};
10use std::collections::HashMap;
11use std::sync::Arc;
12use tokio::sync::RwLock;
13use tower_http::cors::CorsLayer;
14use vectradb_components::{DatabaseStats, SimilarityResult, VectorDatabase, VectraDBError};
15use vectradb_storage::{DatabaseConfig, PersistentVectorDB};
16
17/// API server state
18#[derive(Clone)]
19pub struct AppState {
20    pub db: Arc<RwLock<PersistentVectorDB>>,
21}
22
23/// Request/Response types for API endpoints
24
25#[derive(Debug, Deserialize)]
26pub struct CreateVectorRequest {
27    pub id: String,
28    pub vector: Vec<f32>,
29    pub tags: Option<HashMap<String, String>>,
30}
31
32#[derive(Debug, Deserialize)]
33pub struct UpdateVectorRequest {
34    pub vector: Vec<f32>,
35    pub tags: Option<HashMap<String, String>>,
36}
37
38#[derive(Debug, Deserialize)]
39pub struct UpsertVectorRequest {
40    pub vector: Vec<f32>,
41    pub tags: Option<HashMap<String, String>>,
42}
43
44#[derive(Debug, Deserialize)]
45pub struct SearchRequest {
46    pub vector: Vec<f32>,
47    pub top_k: Option<usize>,
48}
49
50#[derive(Debug, Serialize)]
51pub struct VectorResponse {
52    pub id: String,
53    pub vector: Vec<f32>,
54    pub dimension: usize,
55    pub created_at: u64,
56    pub updated_at: u64,
57    pub tags: HashMap<String, String>,
58}
59
60#[derive(Debug, Serialize)]
61pub struct SearchResponse {
62    pub results: Vec<SimilarityResult>,
63    pub total_time_ms: f64,
64}
65
66#[derive(Debug, Serialize)]
67pub struct ErrorResponse {
68    pub error: String,
69    pub message: String,
70}
71
72/// Create the API router
73pub fn create_router(state: AppState) -> Router {
74    Router::new()
75        .route("/health", get(health_check))
76        .route("/stats", get(get_stats))
77        .route("/vectors", post(create_vector))
78        .route("/vectors/:id", get(get_vector))
79        .route("/vectors/:id", put(update_vector))
80        .route("/vectors/:id", delete(delete_vector))
81        .route("/vectors/:id/upsert", put(upsert_vector))
82        .route("/search", post(search_vectors))
83        .route("/vectors", get(list_vectors))
84        .layer(CorsLayer::permissive())
85        .with_state(state)
86}
87
88/// Health check endpoint
89async fn health_check() -> Result<Json<HashMap<String, String>>, StatusCode> {
90    let mut response = HashMap::new();
91    response.insert("status".to_string(), "healthy".to_string());
92    response.insert("service".to_string(), "vectradb-api".to_string());
93    Ok(Json(response))
94}
95
96/// Get database statistics
97async fn get_stats(
98    State(state): State<AppState>,
99) -> Result<Json<DatabaseStats>, (StatusCode, Json<ErrorResponse>)> {
100    let db = state.db.read().await;
101    match db.get_stats() {
102        Ok(stats) => Ok(Json(stats)),
103        Err(e) => Err((
104            StatusCode::INTERNAL_SERVER_ERROR,
105            Json(ErrorResponse {
106                error: "Database error".to_string(),
107                message: e.to_string(),
108            }),
109        )),
110    }
111}
112
113/// Create a new vector
114async fn create_vector(
115    State(state): State<AppState>,
116    Json(request): Json<CreateVectorRequest>,
117) -> Result<Json<VectorResponse>, (StatusCode, Json<ErrorResponse>)> {
118    let vector = Array1::from_vec(request.vector);
119
120    let mut db = state.db.write().await;
121    match db.create_vector(request.id.clone(), vector, request.tags) {
122        Ok(_) => {
123            // Fetch the created vector to return complete information
124            match db.get_vector(&request.id) {
125                Ok(document) => Ok(Json(VectorResponse {
126                    id: document.metadata.id,
127                    vector: document.data.to_vec(),
128                    dimension: document.metadata.dimension,
129                    created_at: document.metadata.created_at,
130                    updated_at: document.metadata.updated_at,
131                    tags: document.metadata.tags,
132                })),
133                Err(e) => Err((
134                    StatusCode::INTERNAL_SERVER_ERROR,
135                    Json(ErrorResponse {
136                        error: "Failed to fetch created vector".to_string(),
137                        message: e.to_string(),
138                    }),
139                )),
140            }
141        }
142        Err(e) => Err((
143            StatusCode::BAD_REQUEST,
144            Json(ErrorResponse {
145                error: "Failed to create vector".to_string(),
146                message: e.to_string(),
147            }),
148        )),
149    }
150}
151
152/// Get a vector by ID
153async fn get_vector(
154    State(state): State<AppState>,
155    Path(id): Path<String>,
156) -> Result<Json<VectorResponse>, (StatusCode, Json<ErrorResponse>)> {
157    let db = state.db.read().await;
158    match db.get_vector(&id) {
159        Ok(document) => Ok(Json(VectorResponse {
160            id: document.metadata.id,
161            vector: document.data.to_vec(),
162            dimension: document.metadata.dimension,
163            created_at: document.metadata.created_at,
164            updated_at: document.metadata.updated_at,
165            tags: document.metadata.tags,
166        })),
167        Err(VectraDBError::VectorNotFound { .. }) => Err((
168            StatusCode::NOT_FOUND,
169            Json(ErrorResponse {
170                error: "Vector not found".to_string(),
171                message: format!("Vector with ID '{}' not found", id),
172            }),
173        )),
174        Err(e) => Err((
175            StatusCode::INTERNAL_SERVER_ERROR,
176            Json(ErrorResponse {
177                error: "Database error".to_string(),
178                message: e.to_string(),
179            }),
180        )),
181    }
182}
183
184/// Update an existing vector
185async fn update_vector(
186    State(state): State<AppState>,
187    Path(id): Path<String>,
188    Json(request): Json<UpdateVectorRequest>,
189) -> Result<Json<VectorResponse>, (StatusCode, Json<ErrorResponse>)> {
190    let vector = Array1::from_vec(request.vector);
191
192    let mut db = state.db.write().await;
193    match db.update_vector(&id, vector, request.tags) {
194        Ok(_) => {
195            // Fetch the updated vector to return complete information
196            match db.get_vector(&id) {
197                Ok(document) => Ok(Json(VectorResponse {
198                    id: document.metadata.id,
199                    vector: document.data.to_vec(),
200                    dimension: document.metadata.dimension,
201                    created_at: document.metadata.created_at,
202                    updated_at: document.metadata.updated_at,
203                    tags: document.metadata.tags,
204                })),
205                Err(e) => Err((
206                    StatusCode::INTERNAL_SERVER_ERROR,
207                    Json(ErrorResponse {
208                        error: "Failed to fetch updated vector".to_string(),
209                        message: e.to_string(),
210                    }),
211                )),
212            }
213        }
214        Err(VectraDBError::VectorNotFound { .. }) => Err((
215            StatusCode::NOT_FOUND,
216            Json(ErrorResponse {
217                error: "Vector not found".to_string(),
218                message: format!("Vector with ID '{}' not found", id),
219            }),
220        )),
221        Err(e) => Err((
222            StatusCode::BAD_REQUEST,
223            Json(ErrorResponse {
224                error: "Failed to update vector".to_string(),
225                message: e.to_string(),
226            }),
227        )),
228    }
229}
230
231/// Delete a vector by ID
232async fn delete_vector(
233    State(state): State<AppState>,
234    Path(id): Path<String>,
235) -> Result<StatusCode, (StatusCode, Json<ErrorResponse>)> {
236    let mut db = state.db.write().await;
237    match db.delete_vector(&id) {
238        Ok(_) => Ok(StatusCode::NO_CONTENT),
239        Err(VectraDBError::VectorNotFound { .. }) => Err((
240            StatusCode::NOT_FOUND,
241            Json(ErrorResponse {
242                error: "Vector not found".to_string(),
243                message: format!("Vector with ID '{}' not found", id),
244            }),
245        )),
246        Err(e) => Err((
247            StatusCode::INTERNAL_SERVER_ERROR,
248            Json(ErrorResponse {
249                error: "Database error".to_string(),
250                message: e.to_string(),
251            }),
252        )),
253    }
254}
255
256/// Upsert a vector (insert or update)
257async fn upsert_vector(
258    State(state): State<AppState>,
259    Path(id): Path<String>,
260    Json(request): Json<UpsertVectorRequest>,
261) -> Result<Json<VectorResponse>, (StatusCode, Json<ErrorResponse>)> {
262    let vector = Array1::from_vec(request.vector);
263
264    let mut db = state.db.write().await;
265    match db.upsert_vector(id.clone(), vector, request.tags) {
266        Ok(_) => {
267            // Fetch the upserted vector to return complete information
268            match db.get_vector(&id) {
269                Ok(document) => Ok(Json(VectorResponse {
270                    id: document.metadata.id,
271                    vector: document.data.to_vec(),
272                    dimension: document.metadata.dimension,
273                    created_at: document.metadata.created_at,
274                    updated_at: document.metadata.updated_at,
275                    tags: document.metadata.tags,
276                })),
277                Err(e) => Err((
278                    StatusCode::INTERNAL_SERVER_ERROR,
279                    Json(ErrorResponse {
280                        error: "Failed to fetch upserted vector".to_string(),
281                        message: e.to_string(),
282                    }),
283                )),
284            }
285        }
286        Err(e) => Err((
287            StatusCode::BAD_REQUEST,
288            Json(ErrorResponse {
289                error: "Failed to upsert vector".to_string(),
290                message: e.to_string(),
291            }),
292        )),
293    }
294}
295
296/// Search for similar vectors
297async fn search_vectors(
298    State(state): State<AppState>,
299    Json(request): Json<SearchRequest>,
300) -> Result<Json<SearchResponse>, (StatusCode, Json<ErrorResponse>)> {
301    let vector = Array1::from_vec(request.vector);
302    let top_k = request.top_k.unwrap_or(10);
303
304    let start_time = std::time::Instant::now();
305    let db = state.db.read().await;
306
307    match db.search_similar(vector, top_k) {
308        Ok(results) => {
309            let total_time = start_time.elapsed().as_secs_f64() * 1000.0; // Convert to milliseconds
310            Ok(Json(SearchResponse {
311                results,
312                total_time_ms: total_time,
313            }))
314        }
315        Err(e) => Err((
316            StatusCode::BAD_REQUEST,
317            Json(ErrorResponse {
318                error: "Search failed".to_string(),
319                message: e.to_string(),
320            }),
321        )),
322    }
323}
324
325/// List all vector IDs
326async fn list_vectors(
327    State(state): State<AppState>,
328) -> Result<Json<Vec<String>>, (StatusCode, Json<ErrorResponse>)> {
329    let db = state.db.read().await;
330    match db.list_vectors() {
331        Ok(ids) => Ok(Json(ids)),
332        Err(e) => Err((
333            StatusCode::INTERNAL_SERVER_ERROR,
334            Json(ErrorResponse {
335                error: "Database error".to_string(),
336                message: e.to_string(),
337            }),
338        )),
339    }
340}
341
342/// Start the API server
343pub async fn start_server(
344    config: DatabaseConfig,
345    port: u16,
346) -> Result<(), Box<dyn std::error::Error>> {
347    // Initialize database
348    let db = PersistentVectorDB::new(config).await?;
349    let state = AppState {
350        db: Arc::new(RwLock::new(db)),
351    };
352
353    // Create router
354    let app = create_router(state);
355
356    // Start server
357    let listener = tokio::net::TcpListener::bind(format!("0.0.0.0:{}", port)).await?;
358    println!("VectraDB API server running on http://0.0.0.0:{}", port);
359
360    axum::serve(listener, app).await?;
361    Ok(())
362}
363
#[cfg(test)]
mod tests {
    use super::*;

    /// Request structs are plain data: constructing one keeps every field as given.
    /// Synchronous, so it needs no async runtime.
    #[test]
    fn test_create_vector_request() {
        let request = CreateVectorRequest {
            id: "test_id".to_string(),
            vector: vec![1.0, 2.0, 3.0],
            tags: Some(HashMap::from([(
                "category".to_string(),
                "test".to_string(),
            )])),
        };

        assert_eq!(request.id, "test_id");
        assert_eq!(request.vector.len(), 3);
        assert_eq!(
            request
                .tags
                .as_ref()
                .and_then(|tags| tags.get("category"))
                .map(String::as_str),
            Some("test")
        );
    }

    #[test]
    fn test_search_request() {
        let request = SearchRequest {
            vector: vec![1.0, 2.0, 3.0],
            top_k: Some(5),
        };

        assert_eq!(request.vector.len(), 3);
        assert_eq!(request.top_k, Some(5));
    }

    /// The health endpoint needs no database and always reports a healthy service.
    #[tokio::test]
    async fn test_health_check() {
        let Json(body) = health_check()
            .await
            .expect("health_check never returns an error");

        assert_eq!(body.get("status").map(String::as_str), Some("healthy"));
        assert_eq!(
            body.get("service").map(String::as_str),
            Some("vectradb-api")
        );
    }
}