toondb_grpc/
server.rs

1// Copyright 2025 Sushanth (https://github.com/sushanthpy)
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! gRPC Server implementation for Vector Index Service
16
17use crate::error::GrpcError;
18use crate::proto::{
19    self,
20    vector_index_service_server::{VectorIndexService, VectorIndexServiceServer},
21    CreateIndexRequest, CreateIndexResponse, DropIndexRequest, DropIndexResponse,
22    GetStatsRequest, GetStatsResponse, HealthCheckRequest, HealthCheckResponse,
23    HnswConfig as ProtoHnswConfig, IndexInfo, IndexStats, InsertBatchRequest,
24    InsertBatchResponse, InsertStreamRequest, InsertStreamResponse, QueryResults,
25    SearchBatchRequest, SearchBatchResponse, SearchRequest, SearchResponse, SearchResult,
26};
27use dashmap::DashMap;
28use std::sync::Arc;
29use std::time::Instant;
30use tokio_stream::StreamExt;
31use tonic::{Request, Response, Status, Streaming};
32use toondb_index::hnsw::{DistanceMetric, HnswConfig, HnswIndex};
33
34/// Server version
35const VERSION: &str = env!("CARGO_PKG_VERSION");
36
/// Metadata for an index
///
/// One entry is stored per registered index name in `VectorIndexServer::indexes`.
// `metric`, `config` and `created_at` are written by `create_index` but never
// read back in this file, which is presumably why the lint is silenced.
#[allow(dead_code)]
struct IndexEntry {
    /// Shared handle to the HNSW index; handlers clone this `Arc` out of the map.
    index: Arc<HnswIndex>,
    /// Index name (the same string that is used as the map key).
    name: String,
    /// Vector dimension the index was created with.
    dimension: usize,
    /// Distance metric as decoded from the create request.
    metric: proto::DistanceMetric,
    /// HNSW parameters exactly as received from the client (defaulted when the
    /// request carried none); zero fields stay zero here even though the index
    /// itself is built with fallback values.
    config: ProtoHnswConfig,
    /// Creation time in whole seconds since the UNIX epoch.
    created_at: u64,
}
47
/// Vector Index gRPC Server
///
/// Owns every index registered through the service, keyed by name. The RPC
/// handlers only take `&self`, so all mutation goes through the concurrent
/// map below. Nothing in this file persists the map.
pub struct VectorIndexServer {
    /// Map of index name -> index entry
    indexes: DashMap<String, IndexEntry>,
}
53
54impl VectorIndexServer {
55    /// Create a new server instance
56    pub fn new() -> Self {
57        Self {
58            indexes: DashMap::new(),
59        }
60    }
61    
62    /// Create the gRPC service
63    pub fn into_service(self) -> VectorIndexServiceServer<Self> {
64        VectorIndexServiceServer::new(self)
65    }
66    
67    /// Get an index by name and its dimension
68    fn get_index_with_dim(&self, name: &str) -> Result<(Arc<HnswIndex>, usize), GrpcError> {
69        self.indexes
70            .get(name)
71            .map(|entry| (entry.index.clone(), entry.dimension))
72            .ok_or_else(|| GrpcError::IndexNotFound(name.to_string()))
73    }
74    
75    /// Get an index by name
76    fn get_index(&self, name: &str) -> Result<Arc<HnswIndex>, GrpcError> {
77        self.indexes
78            .get(name)
79            .map(|entry| entry.index.clone())
80            .ok_or_else(|| GrpcError::IndexNotFound(name.to_string()))
81    }
82    
83    /// Convert proto metric to internal metric
84    fn convert_metric(metric: proto::DistanceMetric) -> DistanceMetric {
85        match metric {
86            proto::DistanceMetric::L2 => DistanceMetric::Euclidean,
87            proto::DistanceMetric::Cosine => DistanceMetric::Cosine,
88            proto::DistanceMetric::DotProduct => DistanceMetric::DotProduct,
89            _ => DistanceMetric::Cosine, // Default
90        }
91    }
92}
93
94impl Default for VectorIndexServer {
95    fn default() -> Self {
96        Self::new()
97    }
98}
99
100#[tonic::async_trait]
101impl VectorIndexService for VectorIndexServer {
102    async fn create_index(
103        &self,
104        request: Request<CreateIndexRequest>,
105    ) -> Result<Response<CreateIndexResponse>, Status> {
106        let req = request.into_inner();
107        let name = req.name.clone();
108        
109        // Check if index already exists
110        if self.indexes.contains_key(&name) {
111            return Ok(Response::new(CreateIndexResponse {
112                success: false,
113                error: format!("Index '{}' already exists", name),
114                info: None,
115            }));
116        }
117        
118        // Build config
119        let proto_config = req.config.unwrap_or_default();
120        let config = HnswConfig {
121            max_connections: if proto_config.max_connections > 0 {
122                proto_config.max_connections as usize
123            } else {
124                16
125            },
126            max_connections_layer0: if proto_config.max_connections_layer0 > 0 {
127                proto_config.max_connections_layer0 as usize
128            } else {
129                32
130            },
131            ef_construction: if proto_config.ef_construction > 0 {
132                proto_config.ef_construction as usize
133            } else {
134                200
135            },
136            ef_search: if proto_config.ef_search > 0 {
137                proto_config.ef_search as usize
138            } else {
139                50
140            },
141            metric: Self::convert_metric(req.metric()),
142            ..Default::default()
143        };
144        
145        let dimension = req.dimension as usize;
146        let index = HnswIndex::new(dimension, config.clone());
147        let created_at = std::time::SystemTime::now()
148            .duration_since(std::time::UNIX_EPOCH)
149            .unwrap()
150            .as_secs();
151        
152        let entry = IndexEntry {
153            index: Arc::new(index),
154            name: name.clone(),
155            dimension,
156            metric: req.metric(),
157            config: proto_config.clone(),
158            created_at,
159        };
160        
161        self.indexes.insert(name.clone(), entry);
162        
163        tracing::info!("Created index '{}' with dimension {}", name, dimension);
164        
165        Ok(Response::new(CreateIndexResponse {
166            success: true,
167            error: String::new(),
168            info: Some(IndexInfo {
169                name,
170                dimension: dimension as u32,
171                metric: req.metric.into(),
172                config: Some(proto_config),
173                created_at,
174            }),
175        }))
176    }
177    
178    async fn drop_index(
179        &self,
180        request: Request<DropIndexRequest>,
181    ) -> Result<Response<DropIndexResponse>, Status> {
182        let name = request.into_inner().name;
183        
184        match self.indexes.remove(&name) {
185            Some(_) => {
186                tracing::info!("Dropped index '{}'", name);
187                Ok(Response::new(DropIndexResponse {
188                    success: true,
189                    error: String::new(),
190                }))
191            }
192            None => Ok(Response::new(DropIndexResponse {
193                success: false,
194                error: format!("Index '{}' not found", name),
195            })),
196        }
197    }
198    
199    async fn insert_batch(
200        &self,
201        request: Request<InsertBatchRequest>,
202    ) -> Result<Response<InsertBatchResponse>, Status> {
203        let start = Instant::now();
204        let req = request.into_inner();
205        
206        let (index, dimension) = self.get_index_with_dim(&req.index_name)?;
207        
208        // Validate input
209        if req.vectors.len() != req.ids.len() * dimension {
210            return Err(Status::invalid_argument(format!(
211                "Vector data size mismatch: expected {} floats, got {}",
212                req.ids.len() * dimension,
213                req.vectors.len()
214            )));
215        }
216        
217        // Convert IDs to u128
218        let ids: Vec<u128> = req.ids.iter().map(|&id| id as u128).collect();
219        
220        // Use flat batch insert for zero-copy performance
221        match index.insert_batch_flat(&ids, &req.vectors, dimension) {
222            Ok(count) => {
223                let duration_us = start.elapsed().as_micros() as u64;
224                tracing::debug!(
225                    "Inserted {} vectors into '{}' in {}µs",
226                    count,
227                    req.index_name,
228                    duration_us
229                );
230                Ok(Response::new(InsertBatchResponse {
231                    inserted_count: count as u32,
232                    error: String::new(),
233                    duration_us,
234                }))
235            }
236            Err(e) => Ok(Response::new(InsertBatchResponse {
237                inserted_count: 0,
238                error: e,
239                duration_us: start.elapsed().as_micros() as u64,
240            })),
241        }
242    }
243    
244    async fn insert_stream(
245        &self,
246        request: Request<Streaming<InsertStreamRequest>>,
247    ) -> Result<Response<InsertStreamResponse>, Status> {
248        let start = Instant::now();
249        let mut stream = request.into_inner();
250        
251        let mut index_name: Option<String> = None;
252        let mut index: Option<Arc<HnswIndex>> = None;
253        let mut total_inserted = 0u32;
254        let mut errors = Vec::new();
255        
256        while let Some(result) = stream.next().await {
257            match result {
258                Ok(req) => {
259                    // Get index on first message
260                    if index.is_none() {
261                        if req.index_name.is_empty() {
262                            errors.push("First message must include index_name".to_string());
263                            continue;
264                        }
265                        index_name = Some(req.index_name.clone());
266                        match self.get_index(&req.index_name) {
267                            Ok(idx) => index = Some(idx),
268                            Err(e) => {
269                                errors.push(e.to_string());
270                                break;
271                            }
272                        }
273                    }
274                    
275                    // Insert the vector
276                    if let Some(ref idx) = index {
277                        let vector: Vec<f32> = req.vector;
278                        match idx.insert_one_from_slice(req.id as u128, &vector) {
279                            Ok(()) => total_inserted += 1,
280                            Err(e) => errors.push(format!("ID {}: {}", req.id, e)),
281                        }
282                    }
283                }
284                Err(e) => {
285                    errors.push(format!("Stream error: {}", e));
286                    break;
287                }
288            }
289        }
290        
291        let duration_us = start.elapsed().as_micros() as u64;
292        
293        if let Some(name) = &index_name {
294            tracing::debug!(
295                "Stream inserted {} vectors into '{}' in {}µs",
296                total_inserted,
297                name,
298                duration_us
299            );
300        }
301        
302        Ok(Response::new(InsertStreamResponse {
303            total_inserted,
304            errors,
305            duration_us,
306        }))
307    }
308    
309    async fn search(
310        &self,
311        request: Request<SearchRequest>,
312    ) -> Result<Response<SearchResponse>, Status> {
313        let start = Instant::now();
314        let req = request.into_inner();
315        
316        let (index, dimension) = self.get_index_with_dim(&req.index_name)?;
317        
318        // Validate dimension
319        if req.query.len() != dimension {
320            return Err(Status::invalid_argument(format!(
321                "Query dimension mismatch: expected {}, got {}",
322                dimension,
323                req.query.len()
324            )));
325        }
326        
327        let k = req.k.max(1) as usize;
328        
329        // Perform search
330        let results = match index.search(&req.query, k) {
331            Ok(r) => r,
332            Err(e) => {
333                return Ok(Response::new(SearchResponse {
334                    results: vec![],
335                    duration_us: start.elapsed().as_micros() as u64,
336                    error: e,
337                }));
338            }
339        };
340        
341        let duration_us = start.elapsed().as_micros() as u64;
342        
343        Ok(Response::new(SearchResponse {
344            results: results
345                .into_iter()
346                .map(|(id, distance)| SearchResult {
347                    id: id as u64,
348                    distance,
349                })
350                .collect(),
351            duration_us,
352            error: String::new(),
353        }))
354    }
355    
356    async fn search_batch(
357        &self,
358        request: Request<SearchBatchRequest>,
359    ) -> Result<Response<SearchBatchResponse>, Status> {
360        let start = Instant::now();
361        let req = request.into_inner();
362        
363        let (index, dimension) = self.get_index_with_dim(&req.index_name)?;
364        let num_queries = req.num_queries as usize;
365        let k = req.k.max(1) as usize;
366        
367        // Validate
368        if req.queries.len() != num_queries * dimension {
369            return Err(Status::invalid_argument(format!(
370                "Query data size mismatch: expected {} floats, got {}",
371                num_queries * dimension,
372                req.queries.len()
373            )));
374        }
375        
376        // Perform batch search
377        let mut all_results = Vec::with_capacity(num_queries);
378        
379        for i in 0..num_queries {
380            let query = &req.queries[i * dimension..(i + 1) * dimension];
381            let results = match index.search(query, k) {
382                Ok(r) => r,
383                Err(_) => vec![],
384            };
385            
386            all_results.push(QueryResults {
387                results: results
388                    .into_iter()
389                    .map(|(id, distance)| SearchResult {
390                        id: id as u64,
391                        distance,
392                    })
393                    .collect(),
394            });
395        }
396        
397        let duration_us = start.elapsed().as_micros() as u64;
398        
399        Ok(Response::new(SearchBatchResponse {
400            results: all_results,
401            duration_us,
402        }))
403    }
404    
405    async fn get_stats(
406        &self,
407        request: Request<GetStatsRequest>,
408    ) -> Result<Response<GetStatsResponse>, Status> {
409        let name = request.into_inner().index_name;
410        
411        match self.indexes.get(&name) {
412            Some(entry) => {
413                let stats = entry.index.stats();
414                Ok(Response::new(GetStatsResponse {
415                    stats: Some(IndexStats {
416                        num_vectors: stats.num_vectors as u64,
417                        dimension: entry.dimension as u32,
418                        max_layer: stats.max_layer as u32,
419                        memory_bytes: 0, // Memory stats available via separate call
420                        avg_connections: stats.avg_connections,
421                    }),
422                    error: String::new(),
423                }))
424            }
425            None => Ok(Response::new(GetStatsResponse {
426                stats: None,
427                error: format!("Index '{}' not found", name),
428            })),
429        }
430    }
431    
432    async fn health_check(
433        &self,
434        _request: Request<HealthCheckRequest>,
435    ) -> Result<Response<HealthCheckResponse>, Status> {
436        let indexes: Vec<String> = self.indexes.iter().map(|e| e.name.clone()).collect();
437        
438        Ok(Response::new(HealthCheckResponse {
439            status: proto::health_check_response::Status::Serving.into(),
440            version: VERSION.to_string(),
441            indexes,
442        }))
443    }
444}