1use crate::proto::{
11 semantic_cache_service_server::{SemanticCacheService, SemanticCacheServiceServer},
12 SemanticCacheGetRequest, SemanticCacheGetResponse, SemanticCacheInvalidateRequest,
13 SemanticCacheInvalidateResponse, SemanticCachePutRequest, SemanticCachePutResponse,
14 SemanticCacheStatsRequest, SemanticCacheStatsResponse,
15};
16use dashmap::DashMap;
17use std::sync::atomic::{AtomicU64, Ordering};
18use std::time::{Duration, Instant};
19use tonic::{Request, Response, Status};
20
/// One cached value together with the embedding it is matched against.
struct CacheEntry {
    /// Caller-supplied key; echoed back as `matched_key` on a cache hit.
    key: String,
    /// The cached payload returned verbatim on a hit.
    value: String,
    /// Embedding compared to query embeddings via cosine similarity.
    embedding: Vec<f32>,
    /// Absolute expiry instant; `None` means the entry never expires.
    expires_at: Option<Instant>,
}
28
/// Lock-free hit/miss counters for a single named cache.
struct CacheStats {
    /// Number of `get` calls that found a match at or above the threshold.
    hits: AtomicU64,
    /// Number of `get` calls that found no qualifying match.
    misses: AtomicU64,
}
34
/// One named cache: its entries keyed by caller key, plus its statistics.
struct CacheInstance {
    /// Entries indexed by exact key; similarity search scans all of them.
    entries: DashMap<String, CacheEntry>,
    /// Hit/miss counters updated on every `get`.
    stats: CacheStats,
}
40
41impl CacheInstance {
42 fn new() -> Self {
43 Self {
44 entries: DashMap::new(),
45 stats: CacheStats {
46 hits: AtomicU64::new(0),
47 misses: AtomicU64::new(0),
48 },
49 }
50 }
51}
52
/// gRPC service state: independent named caches, created lazily on first use.
pub struct SemanticCacheServer {
    /// Map from cache name to its instance; shared across request handlers.
    caches: DashMap<String, CacheInstance>,
}
57
58impl SemanticCacheServer {
59 pub fn new() -> Self {
60 Self {
61 caches: DashMap::new(),
62 }
63 }
64
65 pub fn into_service(self) -> SemanticCacheServiceServer<Self> {
66 SemanticCacheServiceServer::new(self)
67 }
68
69 fn get_or_create_cache(&self, name: &str) -> dashmap::mapref::one::Ref<'_, String, CacheInstance> {
70 if !self.caches.contains_key(name) {
71 self.caches.insert(name.to_string(), CacheInstance::new());
72 }
73 self.caches.get(name).unwrap()
74 }
75
76 fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
77 if a.len() != b.len() || a.is_empty() {
78 return 0.0;
79 }
80
81 let mut dot = 0.0f32;
82 let mut norm_a = 0.0f32;
83 let mut norm_b = 0.0f32;
84
85 for i in 0..a.len() {
86 dot += a[i] * b[i];
87 norm_a += a[i] * a[i];
88 norm_b += b[i] * b[i];
89 }
90
91 if norm_a == 0.0 || norm_b == 0.0 {
92 return 0.0;
93 }
94
95 dot / (norm_a.sqrt() * norm_b.sqrt())
96 }
97}
98
99impl Default for SemanticCacheServer {
100 fn default() -> Self {
101 Self::new()
102 }
103}
104
105#[tonic::async_trait]
106impl SemanticCacheService for SemanticCacheServer {
107 async fn get(
108 &self,
109 request: Request<SemanticCacheGetRequest>,
110 ) -> Result<Response<SemanticCacheGetResponse>, Status> {
111 let req = request.into_inner();
112 let cache = self.get_or_create_cache(&req.cache_name);
113 let now = Instant::now();
114
115 let mut best_match: Option<(String, String, f32)> = None;
116 let threshold = if req.similarity_threshold > 0.0 {
117 req.similarity_threshold
118 } else {
119 0.85 };
121
122 for entry in cache.entries.iter() {
124 let e = entry.value();
125
126 if let Some(expires_at) = e.expires_at {
128 if now > expires_at {
129 continue;
130 }
131 }
132
133 let similarity = Self::cosine_similarity(&req.query_embedding, &e.embedding);
134 if similarity >= threshold {
135 match &best_match {
136 Some((_, _, best_score)) if similarity > *best_score => {
137 best_match = Some((e.key.clone(), e.value.clone(), similarity));
138 }
139 None => {
140 best_match = Some((e.key.clone(), e.value.clone(), similarity));
141 }
142 _ => {}
143 }
144 }
145 }
146
147 match best_match {
148 Some((key, value, score)) => {
149 cache.stats.hits.fetch_add(1, Ordering::Relaxed);
150 Ok(Response::new(SemanticCacheGetResponse {
151 hit: true,
152 cached_value: value,
153 similarity_score: score,
154 matched_key: key,
155 }))
156 }
157 None => {
158 cache.stats.misses.fetch_add(1, Ordering::Relaxed);
159 Ok(Response::new(SemanticCacheGetResponse {
160 hit: false,
161 cached_value: String::new(),
162 similarity_score: 0.0,
163 matched_key: String::new(),
164 }))
165 }
166 }
167 }
168
169 async fn put(
170 &self,
171 request: Request<SemanticCachePutRequest>,
172 ) -> Result<Response<SemanticCachePutResponse>, Status> {
173 let req = request.into_inner();
174
175 if !self.caches.contains_key(&req.cache_name) {
176 self.caches.insert(req.cache_name.clone(), CacheInstance::new());
177 }
178
179 let cache = self.caches.get(&req.cache_name).unwrap();
180
181 let expires_at = if req.ttl_seconds > 0 {
182 Some(Instant::now() + Duration::from_secs(req.ttl_seconds))
183 } else {
184 None
185 };
186
187 let entry = CacheEntry {
188 key: req.key.clone(),
189 value: req.value,
190 embedding: req.key_embedding,
191 expires_at,
192 };
193
194 cache.entries.insert(req.key, entry);
195
196 Ok(Response::new(SemanticCachePutResponse {
197 success: true,
198 error: String::new(),
199 }))
200 }
201
202 async fn invalidate(
203 &self,
204 request: Request<SemanticCacheInvalidateRequest>,
205 ) -> Result<Response<SemanticCacheInvalidateResponse>, Status> {
206 let req = request.into_inner();
207
208 let count = if let Some(cache) = self.caches.get(&req.cache_name) {
209 if req.pattern.is_empty() {
210 let count = cache.entries.len();
211 cache.entries.clear();
212 count as u32
213 } else {
214 let mut count = 0u32;
215 cache.entries.retain(|k, _| {
216 if k.contains(&req.pattern) {
217 count += 1;
218 false
219 } else {
220 true
221 }
222 });
223 count
224 }
225 } else {
226 0
227 };
228
229 Ok(Response::new(SemanticCacheInvalidateResponse {
230 invalidated_count: count,
231 }))
232 }
233
234 async fn get_stats(
235 &self,
236 request: Request<SemanticCacheStatsRequest>,
237 ) -> Result<Response<SemanticCacheStatsResponse>, Status> {
238 let req = request.into_inner();
239
240 match self.caches.get(&req.cache_name) {
241 Some(cache) => {
242 let hits = cache.stats.hits.load(Ordering::Relaxed);
243 let misses = cache.stats.misses.load(Ordering::Relaxed);
244 let total = hits + misses;
245 let hit_rate = if total > 0 {
246 hits as f32 / total as f32
247 } else {
248 0.0
249 };
250
251 Ok(Response::new(SemanticCacheStatsResponse {
252 hits,
253 misses,
254 entry_count: cache.entries.len() as u64,
255 hit_rate,
256 }))
257 }
258 None => Ok(Response::new(SemanticCacheStatsResponse {
259 hits: 0,
260 misses: 0,
261 entry_count: 0,
262 hit_rate: 0.0,
263 })),
264 }
265 }
266}