// File: sochdb_grpc/server.rs

// SPDX-License-Identifier: AGPL-3.0-or-later
// SochDB - LLM-Optimized Embedded Database
// Copyright (C) 2026 Sushanth Reddy Vanagala (https://github.com/sushanthpy)
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.

//! gRPC server implementation for the Vector Index Service.

use std::sync::Arc;
use std::time::Instant;

use dashmap::mapref::entry::Entry;
use dashmap::DashMap;
use sochdb_index::hnsw::{DistanceMetric, HnswConfig, HnswIndex};
use tokio_stream::StreamExt;
use tonic::{Request, Response, Status, Streaming};

use crate::error::GrpcError;
use crate::proto::{
    self,
    vector_index_service_server::{VectorIndexService, VectorIndexServiceServer},
    CreateIndexRequest, CreateIndexResponse, DropIndexRequest, DropIndexResponse,
    GetStatsRequest, GetStatsResponse, HealthCheckRequest, HealthCheckResponse,
    HnswConfig as ProtoHnswConfig, IndexInfo, IndexStats, InsertBatchRequest,
    InsertBatchResponse, InsertStreamRequest, InsertStreamResponse, QueryResults,
    SearchBatchRequest, SearchBatchResponse, SearchRequest, SearchResponse, SearchResult,
};

37/// Server version
38const VERSION: &str = env!("CARGO_PKG_VERSION");
39
40/// Metadata for an index
41#[allow(dead_code)]
42struct IndexEntry {
43    index: Arc<HnswIndex>,
44    name: String,
45    dimension: usize,
46    metric: proto::DistanceMetric,
47    config: ProtoHnswConfig,
48    created_at: u64,
49}
50
51/// Vector Index gRPC Server
52pub struct VectorIndexServer {
53    /// Map of index name -> index entry
54    indexes: DashMap<String, IndexEntry>,
55}
56
57impl VectorIndexServer {
58    /// Create a new server instance
59    pub fn new() -> Self {
60        Self {
61            indexes: DashMap::new(),
62        }
63    }
64    
65    /// Create the gRPC service
66    pub fn into_service(self) -> VectorIndexServiceServer<Self> {
67        VectorIndexServiceServer::new(self)
68    }
69    
70    /// Get an index by name and its dimension
71    fn get_index_with_dim(&self, name: &str) -> Result<(Arc<HnswIndex>, usize), GrpcError> {
72        self.indexes
73            .get(name)
74            .map(|entry| (entry.index.clone(), entry.dimension))
75            .ok_or_else(|| GrpcError::IndexNotFound(name.to_string()))
76    }
77    
78    /// Get an index by name
79    fn get_index(&self, name: &str) -> Result<Arc<HnswIndex>, GrpcError> {
80        self.indexes
81            .get(name)
82            .map(|entry| entry.index.clone())
83            .ok_or_else(|| GrpcError::IndexNotFound(name.to_string()))
84    }
85    
86    /// Convert proto metric to internal metric
87    fn convert_metric(metric: proto::DistanceMetric) -> DistanceMetric {
88        match metric {
89            proto::DistanceMetric::L2 => DistanceMetric::Euclidean,
90            proto::DistanceMetric::Cosine => DistanceMetric::Cosine,
91            proto::DistanceMetric::DotProduct => DistanceMetric::DotProduct,
92            _ => DistanceMetric::Cosine, // Default
93        }
94    }
95}
96
97impl Default for VectorIndexServer {
98    fn default() -> Self {
99        Self::new()
100    }
101}
102
103#[tonic::async_trait]
104impl VectorIndexService for VectorIndexServer {
105    async fn create_index(
106        &self,
107        request: Request<CreateIndexRequest>,
108    ) -> Result<Response<CreateIndexResponse>, Status> {
109        let req = request.into_inner();
110        let name = req.name.clone();
111        
112        // Check if index already exists
113        if self.indexes.contains_key(&name) {
114            return Ok(Response::new(CreateIndexResponse {
115                success: false,
116                error: format!("Index '{}' already exists", name),
117                info: None,
118            }));
119        }
120        
121        // Build config
122        let proto_config = req.config.unwrap_or_default();
123        let config = HnswConfig {
124            max_connections: if proto_config.max_connections > 0 {
125                proto_config.max_connections as usize
126            } else {
127                16
128            },
129            max_connections_layer0: if proto_config.max_connections_layer0 > 0 {
130                proto_config.max_connections_layer0 as usize
131            } else {
132                32
133            },
134            ef_construction: if proto_config.ef_construction > 0 {
135                proto_config.ef_construction as usize
136            } else {
137                200
138            },
139            ef_search: if proto_config.ef_search > 0 {
140                proto_config.ef_search as usize
141            } else {
142                50
143            },
144            metric: Self::convert_metric(req.metric()),
145            ..Default::default()
146        };
147        
148        let dimension = req.dimension as usize;
149        let index = HnswIndex::new(dimension, config.clone());
150        let created_at = std::time::SystemTime::now()
151            .duration_since(std::time::UNIX_EPOCH)
152            .unwrap()
153            .as_secs();
154        
155        let entry = IndexEntry {
156            index: Arc::new(index),
157            name: name.clone(),
158            dimension,
159            metric: req.metric(),
160            config: proto_config.clone(),
161            created_at,
162        };
163        
164        self.indexes.insert(name.clone(), entry);
165        
166        tracing::info!("Created index '{}' with dimension {}", name, dimension);
167        
168        Ok(Response::new(CreateIndexResponse {
169            success: true,
170            error: String::new(),
171            info: Some(IndexInfo {
172                name,
173                dimension: dimension as u32,
174                metric: req.metric.into(),
175                config: Some(proto_config),
176                created_at,
177            }),
178        }))
179    }
180    
181    async fn drop_index(
182        &self,
183        request: Request<DropIndexRequest>,
184    ) -> Result<Response<DropIndexResponse>, Status> {
185        let name = request.into_inner().name;
186        
187        match self.indexes.remove(&name) {
188            Some(_) => {
189                tracing::info!("Dropped index '{}'", name);
190                Ok(Response::new(DropIndexResponse {
191                    success: true,
192                    error: String::new(),
193                }))
194            }
195            None => Ok(Response::new(DropIndexResponse {
196                success: false,
197                error: format!("Index '{}' not found", name),
198            })),
199        }
200    }
201    
202    async fn insert_batch(
203        &self,
204        request: Request<InsertBatchRequest>,
205    ) -> Result<Response<InsertBatchResponse>, Status> {
206        let start = Instant::now();
207        let req = request.into_inner();
208        
209        let (index, dimension) = self.get_index_with_dim(&req.index_name)?;
210        
211        // Validate input
212        if req.vectors.len() != req.ids.len() * dimension {
213            return Err(Status::invalid_argument(format!(
214                "Vector data size mismatch: expected {} floats, got {}",
215                req.ids.len() * dimension,
216                req.vectors.len()
217            )));
218        }
219        
220        // Convert IDs to u128
221        let ids: Vec<u128> = req.ids.iter().map(|&id| id as u128).collect();
222        
223        // Use flat batch insert for zero-copy performance
224        match index.insert_batch_flat(&ids, &req.vectors, dimension) {
225            Ok(count) => {
226                let duration_us = start.elapsed().as_micros() as u64;
227                tracing::debug!(
228                    "Inserted {} vectors into '{}' in {}µs",
229                    count,
230                    req.index_name,
231                    duration_us
232                );
233                Ok(Response::new(InsertBatchResponse {
234                    inserted_count: count as u32,
235                    error: String::new(),
236                    duration_us,
237                }))
238            }
239            Err(e) => Ok(Response::new(InsertBatchResponse {
240                inserted_count: 0,
241                error: e,
242                duration_us: start.elapsed().as_micros() as u64,
243            })),
244        }
245    }
246    
247    async fn insert_stream(
248        &self,
249        request: Request<Streaming<InsertStreamRequest>>,
250    ) -> Result<Response<InsertStreamResponse>, Status> {
251        let start = Instant::now();
252        let mut stream = request.into_inner();
253        
254        let mut index_name: Option<String> = None;
255        let mut index: Option<Arc<HnswIndex>> = None;
256        let mut total_inserted = 0u32;
257        let mut errors = Vec::new();
258        
259        while let Some(result) = stream.next().await {
260            match result {
261                Ok(req) => {
262                    // Get index on first message
263                    if index.is_none() {
264                        if req.index_name.is_empty() {
265                            errors.push("First message must include index_name".to_string());
266                            continue;
267                        }
268                        index_name = Some(req.index_name.clone());
269                        match self.get_index(&req.index_name) {
270                            Ok(idx) => index = Some(idx),
271                            Err(e) => {
272                                errors.push(e.to_string());
273                                break;
274                            }
275                        }
276                    }
277                    
278                    // Insert the vector
279                    if let Some(ref idx) = index {
280                        let vector: Vec<f32> = req.vector;
281                        match idx.insert_one_from_slice(req.id as u128, &vector) {
282                            Ok(()) => total_inserted += 1,
283                            Err(e) => errors.push(format!("ID {}: {}", req.id, e)),
284                        }
285                    }
286                }
287                Err(e) => {
288                    errors.push(format!("Stream error: {}", e));
289                    break;
290                }
291            }
292        }
293        
294        let duration_us = start.elapsed().as_micros() as u64;
295        
296        if let Some(name) = &index_name {
297            tracing::debug!(
298                "Stream inserted {} vectors into '{}' in {}µs",
299                total_inserted,
300                name,
301                duration_us
302            );
303        }
304        
305        Ok(Response::new(InsertStreamResponse {
306            total_inserted,
307            errors,
308            duration_us,
309        }))
310    }
311    
312    async fn search(
313        &self,
314        request: Request<SearchRequest>,
315    ) -> Result<Response<SearchResponse>, Status> {
316        let start = Instant::now();
317        let req = request.into_inner();
318        
319        let (index, dimension) = self.get_index_with_dim(&req.index_name)?;
320        
321        // Validate dimension
322        if req.query.len() != dimension {
323            return Err(Status::invalid_argument(format!(
324                "Query dimension mismatch: expected {}, got {}",
325                dimension,
326                req.query.len()
327            )));
328        }
329        
330        let k = req.k.max(1) as usize;
331        
332        // Perform search
333        let results = match index.search(&req.query, k) {
334            Ok(r) => r,
335            Err(e) => {
336                return Ok(Response::new(SearchResponse {
337                    results: vec![],
338                    duration_us: start.elapsed().as_micros() as u64,
339                    error: e,
340                }));
341            }
342        };
343        
344        let duration_us = start.elapsed().as_micros() as u64;
345        
346        Ok(Response::new(SearchResponse {
347            results: results
348                .into_iter()
349                .map(|(id, distance)| SearchResult {
350                    id: id as u64,
351                    distance,
352                })
353                .collect(),
354            duration_us,
355            error: String::new(),
356        }))
357    }
358    
359    async fn search_batch(
360        &self,
361        request: Request<SearchBatchRequest>,
362    ) -> Result<Response<SearchBatchResponse>, Status> {
363        let start = Instant::now();
364        let req = request.into_inner();
365        
366        let (index, dimension) = self.get_index_with_dim(&req.index_name)?;
367        let num_queries = req.num_queries as usize;
368        let k = req.k.max(1) as usize;
369        
370        // Validate
371        if req.queries.len() != num_queries * dimension {
372            return Err(Status::invalid_argument(format!(
373                "Query data size mismatch: expected {} floats, got {}",
374                num_queries * dimension,
375                req.queries.len()
376            )));
377        }
378        
379        // Perform batch search
380        let mut all_results = Vec::with_capacity(num_queries);
381        
382        for i in 0..num_queries {
383            let query = &req.queries[i * dimension..(i + 1) * dimension];
384            let results = match index.search(query, k) {
385                Ok(r) => r,
386                Err(_) => vec![],
387            };
388            
389            all_results.push(QueryResults {
390                results: results
391                    .into_iter()
392                    .map(|(id, distance)| SearchResult {
393                        id: id as u64,
394                        distance,
395                    })
396                    .collect(),
397            });
398        }
399        
400        let duration_us = start.elapsed().as_micros() as u64;
401        
402        Ok(Response::new(SearchBatchResponse {
403            results: all_results,
404            duration_us,
405        }))
406    }
407    
408    async fn get_stats(
409        &self,
410        request: Request<GetStatsRequest>,
411    ) -> Result<Response<GetStatsResponse>, Status> {
412        let name = request.into_inner().index_name;
413        
414        match self.indexes.get(&name) {
415            Some(entry) => {
416                let stats = entry.index.stats();
417                Ok(Response::new(GetStatsResponse {
418                    stats: Some(IndexStats {
419                        num_vectors: stats.num_vectors as u64,
420                        dimension: entry.dimension as u32,
421                        max_layer: stats.max_layer as u32,
422                        memory_bytes: 0, // Memory stats available via separate call
423                        avg_connections: stats.avg_connections,
424                    }),
425                    error: String::new(),
426                }))
427            }
428            None => Ok(Response::new(GetStatsResponse {
429                stats: None,
430                error: format!("Index '{}' not found", name),
431            })),
432        }
433    }
434    
435    async fn health_check(
436        &self,
437        _request: Request<HealthCheckRequest>,
438    ) -> Result<Response<HealthCheckResponse>, Status> {
439        let indexes: Vec<String> = self.indexes.iter().map(|e| e.name.clone()).collect();
440        
441        Ok(Response::new(HealthCheckResponse {
442            status: proto::health_check_response::Status::Serving.into(),
443            version: VERSION.to_string(),
444            indexes,
445        }))
446    }
447}