sochdb_grpc/semantic_cache_server.rs

// Copyright 2025 Sushanth (https://github.com/sushanthpy)
//
// This program is free software: you can redistribute it and/or modify
// you may not use this file except in compliance with the License.

//! Semantic Cache Service gRPC Implementation
//!
//! Provides semantic caching for LLM queries via gRPC.
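//!
//! A minimal usage sketch (not taken from this crate's docs): it assumes the
//! crate and module path shown in the `use` line, a tokio runtime, and a
//! hypothetical listen address, and mounts the service with tonic's standard
//! `Server::builder()` API.
//!
//! ```ignore
//! use sochdb_grpc::semantic_cache_server::SemanticCacheServer;
//!
//! #[tokio::main]
//! async fn main() -> Result<(), Box<dyn std::error::Error>> {
//!     let addr = "127.0.0.1:50051".parse()?;
//!     tonic::transport::Server::builder()
//!         .add_service(SemanticCacheServer::new().into_service())
//!         .serve(addr)
//!         .await?;
//!     Ok(())
//! }
//! ```
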
use crate::proto::{
    semantic_cache_service_server::{SemanticCacheService, SemanticCacheServiceServer},
    SemanticCacheGetRequest, SemanticCacheGetResponse, SemanticCacheInvalidateRequest,
    SemanticCacheInvalidateResponse, SemanticCachePutRequest, SemanticCachePutResponse,
    SemanticCacheStatsRequest, SemanticCacheStatsResponse,
};
use dashmap::DashMap;
use std::sync::atomic::{AtomicU64, Ordering};
use std::time::{Duration, Instant};
use tonic::{Request, Response, Status};

/// Cache entry with embedding and TTL
struct CacheEntry {
    key: String,
    value: String,
    embedding: Vec<f32>,
    expires_at: Option<Instant>,
}

/// Cache statistics
struct CacheStats {
    hits: AtomicU64,
    misses: AtomicU64,
}

/// In-memory cache per cache name
struct CacheInstance {
    entries: DashMap<String, CacheEntry>,
    stats: CacheStats,
}

impl CacheInstance {
    fn new() -> Self {
        Self {
            entries: DashMap::new(),
            stats: CacheStats {
                hits: AtomicU64::new(0),
                misses: AtomicU64::new(0),
            },
        }
    }
}

/// Semantic Cache gRPC Server
pub struct SemanticCacheServer {
    caches: DashMap<String, CacheInstance>,
}

impl SemanticCacheServer {
    pub fn new() -> Self {
        Self {
            caches: DashMap::new(),
        }
    }

    pub fn into_service(self) -> SemanticCacheServiceServer<Self> {
        SemanticCacheServiceServer::new(self)
    }

    /// Get the named cache instance, creating it if it does not exist yet.
    fn get_or_create_cache(&self, name: &str) -> dashmap::mapref::one::Ref<'_, String, CacheInstance> {
        // Use the entry API so two concurrent callers cannot race between the
        // existence check and the insert and clobber each other's instance.
        self.caches
            .entry(name.to_string())
            .or_insert_with(CacheInstance::new);
        self.caches
            .get(name)
            .expect("cache instance was just created and caches are never removed")
    }

    /// Cosine similarity between two embeddings. Returns 0.0 for mismatched
    /// lengths, empty inputs, or zero-norm vectors.
    fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
        if a.len() != b.len() || a.is_empty() {
            return 0.0;
        }

        let mut dot = 0.0f32;
        let mut norm_a = 0.0f32;
        let mut norm_b = 0.0f32;

        for i in 0..a.len() {
            dot += a[i] * b[i];
            norm_a += a[i] * a[i];
            norm_b += b[i] * b[i];
        }

        if norm_a == 0.0 || norm_b == 0.0 {
            return 0.0;
        }

        dot / (norm_a.sqrt() * norm_b.sqrt())
    }
}

impl Default for SemanticCacheServer {
    fn default() -> Self {
        Self::new()
    }
}

#[tonic::async_trait]
impl SemanticCacheService for SemanticCacheServer {
    async fn get(
        &self,
        request: Request<SemanticCacheGetRequest>,
    ) -> Result<Response<SemanticCacheGetResponse>, Status> {
        let req = request.into_inner();
        let cache = self.get_or_create_cache(&req.cache_name);
        let now = Instant::now();

        let mut best_match: Option<(String, String, f32)> = None;
        let threshold = if req.similarity_threshold > 0.0 {
            req.similarity_threshold
        } else {
            0.85 // Default threshold
        };

        // Search for semantically similar entries
        for entry in cache.entries.iter() {
            let e = entry.value();

            // Check TTL
            if let Some(expires_at) = e.expires_at {
                if now > expires_at {
                    continue;
                }
            }

            let similarity = Self::cosine_similarity(&req.query_embedding, &e.embedding);
            if similarity >= threshold {
                match &best_match {
                    Some((_, _, best_score)) if similarity > *best_score => {
                        best_match = Some((e.key.clone(), e.value.clone(), similarity));
                    }
                    None => {
                        best_match = Some((e.key.clone(), e.value.clone(), similarity));
                    }
                    _ => {}
                }
            }
        }

        match best_match {
            Some((key, value, score)) => {
                cache.stats.hits.fetch_add(1, Ordering::Relaxed);
                Ok(Response::new(SemanticCacheGetResponse {
                    hit: true,
                    cached_value: value,
                    similarity_score: score,
                    matched_key: key,
                }))
            }
            None => {
                cache.stats.misses.fetch_add(1, Ordering::Relaxed);
                Ok(Response::new(SemanticCacheGetResponse {
                    hit: false,
                    cached_value: String::new(),
                    similarity_score: 0.0,
                    matched_key: String::new(),
                }))
            }
        }
    }

    async fn put(
        &self,
        request: Request<SemanticCachePutRequest>,
    ) -> Result<Response<SemanticCachePutResponse>, Status> {
        let req = request.into_inner();

        let cache = self.get_or_create_cache(&req.cache_name);

        let expires_at = if req.ttl_seconds > 0 {
            Some(Instant::now() + Duration::from_secs(req.ttl_seconds))
        } else {
            None
        };

        let entry = CacheEntry {
            key: req.key.clone(),
            value: req.value,
            embedding: req.key_embedding,
            expires_at,
        };

        cache.entries.insert(req.key, entry);

        Ok(Response::new(SemanticCachePutResponse {
            success: true,
            error: String::new(),
        }))
    }

    async fn invalidate(
        &self,
        request: Request<SemanticCacheInvalidateRequest>,
    ) -> Result<Response<SemanticCacheInvalidateResponse>, Status> {
        let req = request.into_inner();

        let count = if let Some(cache) = self.caches.get(&req.cache_name) {
            if req.pattern.is_empty() {
                let count = cache.entries.len();
                cache.entries.clear();
                count as u32
            } else {
                let mut count = 0u32;
                cache.entries.retain(|k, _| {
                    if k.contains(&req.pattern) {
                        count += 1;
                        false
                    } else {
                        true
                    }
                });
                count
            }
        } else {
            0
        };

        Ok(Response::new(SemanticCacheInvalidateResponse {
            invalidated_count: count,
        }))
    }

    async fn get_stats(
        &self,
        request: Request<SemanticCacheStatsRequest>,
    ) -> Result<Response<SemanticCacheStatsResponse>, Status> {
        let req = request.into_inner();

        match self.caches.get(&req.cache_name) {
            Some(cache) => {
                let hits = cache.stats.hits.load(Ordering::Relaxed);
                let misses = cache.stats.misses.load(Ordering::Relaxed);
                let total = hits + misses;
                let hit_rate = if total > 0 {
                    hits as f32 / total as f32
                } else {
                    0.0
                };

                Ok(Response::new(SemanticCacheStatsResponse {
                    hits,
                    misses,
                    entry_count: cache.entries.len() as u64,
                    hit_rate,
                }))
            }
            None => Ok(Response::new(SemanticCacheStatsResponse {
                hits: 0,
                misses: 0,
                entry_count: 0,
                hit_rate: 0.0,
            })),
        }
    }
}
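
// A small test sketch, not part of the original file. Assumptions: `tokio`
// (with the `macros` feature) is available for the async test, and the
// prost-generated request structs implement `Default` (prost normally derives
// it), so `..Default::default()` covers any fields not listed here.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn cosine_similarity_handles_edge_cases() {
        let a = vec![1.0f32, 0.0, 0.0];
        let b = vec![0.0f32, 1.0, 0.0];
        // Identical vectors score 1.0; orthogonal vectors score 0.0.
        assert!((SemanticCacheServer::cosine_similarity(&a, &a) - 1.0).abs() < 1e-6);
        assert!(SemanticCacheServer::cosine_similarity(&a, &b).abs() < 1e-6);
        // Mismatched lengths and empty inputs fall back to 0.0.
        assert_eq!(SemanticCacheServer::cosine_similarity(&a, &[1.0]), 0.0);
        assert_eq!(SemanticCacheServer::cosine_similarity(&[], &[]), 0.0);
    }

    // Hypothetical round trip: put one entry, then get with the same embedding.
    #[tokio::test]
    async fn put_then_get_returns_a_hit() {
        let server = SemanticCacheServer::new();

        let put = SemanticCachePutRequest {
            cache_name: "test".to_string(),
            key: "what is rust".to_string(),
            value: "a systems programming language".to_string(),
            key_embedding: vec![1.0, 0.0, 0.0],
            ttl_seconds: 0, // no expiry
            ..Default::default()
        };
        server.put(Request::new(put)).await.unwrap();

        // similarity_threshold of 0.0 falls back to the 0.85 default; an
        // identical embedding scores 1.0, so this should be a hit.
        let get = SemanticCacheGetRequest {
            cache_name: "test".to_string(),
            query_embedding: vec![1.0, 0.0, 0.0],
            similarity_threshold: 0.0,
            ..Default::default()
        };
        let resp = server.get(Request::new(get)).await.unwrap().into_inner();
        assert!(resp.hit);
        assert_eq!(resp.matched_key, "what is rust");
        assert_eq!(resp.cached_value, "a systems programming language");
    }
}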