1use crate::error::{SfError, SfResult};
6use moka::future::Cache;
7use serde::{Deserialize, Serialize};
8use std::hash::{Hash, Hasher};
9use std::sync::Arc;
10use std::time::Duration;
11use tracing::{debug, info};
12
/// Tuning knobs shared by [`QueryCache`] and [`RecordCache`].
///
/// A configuration whose `max_capacity` is zero or whose `ttl` is zero (see
/// [`CacheConfig::disabled`]) produces a cache that never stores entries.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CacheConfig {
    /// Maximum number of entries held before the cache starts evicting.
    pub max_capacity: u64,

    /// Time-to-live: how long an entry stays valid after it was inserted.
    pub ttl: Duration,

    /// Optional time-to-idle: an entry that is not read within this window
    /// may expire before its TTL. `None` means only `ttl` applies.
    pub tti: Option<Duration>,
}
25
26impl Default for CacheConfig {
27 fn default() -> Self {
28 Self {
29 max_capacity: 10_000,
30 ttl: Duration::from_secs(300), tti: Some(Duration::from_secs(60)), }
33 }
34}
35
36impl CacheConfig {
37 pub fn new() -> Self {
39 Self::default()
40 }
41
42 pub fn max_capacity(mut self, capacity: u64) -> Self {
44 self.max_capacity = capacity;
45 self
46 }
47
48 pub fn ttl(mut self, duration: Duration) -> Self {
50 self.ttl = duration;
51 self
52 }
53
54 pub fn tti(mut self, duration: Duration) -> Self {
56 self.tti = Some(duration);
57 self
58 }
59
60 pub fn disabled() -> Self {
62 Self {
63 max_capacity: 0,
64 ttl: Duration::from_secs(0),
65 tti: None,
66 }
67 }
68}
69
/// Cache key for [`QueryCache`]: the query text, compared byte-for-byte.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
struct QueryKey {
    query: String,
}

impl QueryKey {
    /// Builds a key from anything convertible into an owned `String`.
    fn new(query: impl Into<String>) -> Self {
        let query: String = query.into();
        Self { query }
    }
}
83
/// Cache key for [`RecordCache`]: an object-type name plus a record id.
///
/// `Hash` is derived (rather than hand-written) so it always agrees with the
/// derived `PartialEq`/`Eq`, which compare both fields.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
struct RecordKey {
    /// Object type name, exactly as supplied by the caller.
    sobject: String,
    /// Record id, exactly as supplied by the caller.
    id: String,
}

impl RecordKey {
    /// Builds a key from anything convertible into owned `String`s.
    fn new(sobject: impl Into<String>, id: impl Into<String>) -> Self {
        Self {
            sobject: sobject.into(),
            id: id.into(),
        }
    }
}
106
107#[derive(Debug, Clone, Serialize, Deserialize)]
109struct CachedValue<T> {
110 data: T,
111 cached_at: i64, }
113
114impl<T> CachedValue<T> {
115 fn new(data: T) -> Self {
116 Self {
117 data,
118 cached_at: chrono::Utc::now().timestamp(),
119 }
120 }
121}
122
123pub struct QueryCache {
125 cache: Arc<Cache<QueryKey, Vec<u8>>>,
126 enabled: bool,
127}
128
129impl QueryCache {
130 pub fn new(config: CacheConfig) -> Self {
132 let enabled = config.max_capacity > 0 && config.ttl.as_secs() > 0;
133
134 let cache = Cache::builder()
135 .max_capacity(config.max_capacity)
136 .time_to_live(config.ttl)
137 .time_to_idle(config.tti.unwrap_or(config.ttl))
138 .build();
139
140 if enabled {
141 info!(
142 "Query cache enabled with capacity {} and TTL {:?}",
143 config.max_capacity, config.ttl
144 );
145 } else {
146 info!("Query cache disabled");
147 }
148
149 Self {
150 cache: Arc::new(cache),
151 enabled,
152 }
153 }
154
155 pub async fn get<T>(&self, query: &str) -> Option<Vec<T>>
157 where
158 T: for<'de> Deserialize<'de>,
159 {
160 if !self.enabled {
161 return None;
162 }
163
164 let key = QueryKey::new(query);
165
166 if let Some(cached_bytes) = self.cache.get(&key).await {
167 match serde_json::from_slice::<CachedValue<Vec<T>>>(&cached_bytes) {
168 Ok(cached_value) => {
169 debug!("Cache hit for query: {}", query);
170 Some(cached_value.data)
171 }
172 Err(e) => {
173 debug!("Cache deserialization error: {}", e);
174 None
175 }
176 }
177 } else {
178 debug!("Cache miss for query: {}", query);
179 None
180 }
181 }
182
183 pub async fn set<T>(&self, query: &str, data: Vec<T>) -> SfResult<()>
185 where
186 T: Serialize,
187 {
188 if !self.enabled {
189 return Ok(());
190 }
191
192 let key = QueryKey::new(query);
193 let cached_value = CachedValue::new(data);
194
195 match serde_json::to_vec(&cached_value) {
196 Ok(bytes) => {
197 self.cache.insert(key, bytes).await;
198 debug!("Cached query results: {}", query);
199 Ok(())
200 }
201 Err(e) => {
202 debug!("Failed to serialize cache entry: {}", e);
203 Err(SfError::Cache(format!("Serialization failed: {}", e)))
204 }
205 }
206 }
207
208 pub async fn invalidate(&self, query: &str) {
210 if !self.enabled {
211 return;
212 }
213
214 let key = QueryKey::new(query);
215 self.cache.invalidate(&key).await;
216 debug!("Invalidated cache for query: {}", query);
217 }
218
219 pub async fn clear(&self) {
221 if !self.enabled {
222 return;
223 }
224
225 self.cache.invalidate_all();
226 info!("Cleared all query cache entries");
227 }
228
229 pub fn stats(&self) -> CacheStats {
231 CacheStats {
232 entry_count: self.cache.entry_count(),
233 weighted_size: self.cache.weighted_size(),
234 }
235 }
236}
237
/// Size statistics returned by `QueryCache::stats`.
///
/// A plain pair of counters, so it is `Copy` and comparable.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub struct CacheStats {
    /// Number of entries the underlying cache reports holding.
    pub entry_count: u64,

    /// Weighted size of all entries, as reported by the underlying cache.
    pub weighted_size: u64,
}
247
248pub struct RecordCache {
250 cache: Arc<Cache<RecordKey, Vec<u8>>>,
251 enabled: bool,
252}
253
254impl RecordCache {
255 pub fn new(config: CacheConfig) -> Self {
257 let enabled = config.max_capacity > 0 && config.ttl.as_secs() > 0;
258
259 let cache = Cache::builder()
260 .max_capacity(config.max_capacity)
261 .time_to_live(config.ttl)
262 .time_to_idle(config.tti.unwrap_or(config.ttl))
263 .build();
264
265 Self {
266 cache: Arc::new(cache),
267 enabled,
268 }
269 }
270
271 pub async fn get<T>(&self, sobject: &str, id: &str) -> Option<T>
273 where
274 T: for<'de> Deserialize<'de>,
275 {
276 if !self.enabled {
277 return None;
278 }
279
280 let key = RecordKey::new(sobject, id);
281
282 if let Some(cached_bytes) = self.cache.get(&key).await {
283 match serde_json::from_slice::<CachedValue<T>>(&cached_bytes) {
284 Ok(cached_value) => {
285 debug!("Cache hit for {} {}", sobject, id);
286 Some(cached_value.data)
287 }
288 Err(e) => {
289 debug!("Cache deserialization error: {}", e);
290 None
291 }
292 }
293 } else {
294 None
295 }
296 }
297
298 pub async fn set<T>(&self, sobject: &str, id: &str, data: T) -> SfResult<()>
300 where
301 T: Serialize,
302 {
303 if !self.enabled {
304 return Ok(());
305 }
306
307 let key = RecordKey::new(sobject, id);
308 let cached_value = CachedValue::new(data);
309
310 match serde_json::to_vec(&cached_value) {
311 Ok(bytes) => {
312 self.cache.insert(key, bytes).await;
313 debug!("Cached {} {}", sobject, id);
314 Ok(())
315 }
316 Err(e) => Err(SfError::Cache(format!("Serialization failed: {}", e))),
317 }
318 }
319
320 pub async fn invalidate(&self, sobject: &str, id: &str) {
322 if !self.enabled {
323 return;
324 }
325
326 let key = RecordKey::new(sobject, id);
327 self.cache.invalidate(&key).await;
328 debug!("Invalidated cache for {} {}", sobject, id);
329 }
330
331 pub async fn invalidate_sobject(&self, sobject: &str) {
333 if !self.enabled {
334 return;
335 }
336
337 let sobject_owned = sobject.to_string();
340 let _ = self
341 .cache
342 .invalidate_entries_if(move |key, _| key.sobject == sobject_owned);
343 info!("Invalidated all cached {} records", sobject);
344 }
345}
346
#[cfg(test)]
mod tests {
    use super::*;

    #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
    struct TestRecord {
        id: String,
        name: String,
    }

    const ACCOUNT_QUERY: &str = "SELECT Id FROM Account";

    /// One-row result set shared by the tests below.
    fn sample_records() -> Vec<TestRecord> {
        vec![TestRecord {
            id: "1".to_string(),
            name: "Test".to_string(),
        }]
    }

    #[tokio::test]
    async fn test_query_cache() {
        let cache = QueryCache::new(CacheConfig::new().ttl(Duration::from_secs(60)));
        let records = sample_records();

        // Nothing has been stored yet.
        assert!(cache.get::<TestRecord>(ACCOUNT_QUERY).await.is_none());

        cache.set(ACCOUNT_QUERY, records.clone()).await.unwrap();

        // The rows come back unchanged after the JSON round-trip.
        assert_eq!(cache.get::<TestRecord>(ACCOUNT_QUERY).await, Some(records));
    }

    #[tokio::test]
    async fn test_cache_disabled() {
        let cache = QueryCache::new(CacheConfig::disabled());

        // Writing to a disabled cache succeeds but nothing is retained.
        cache.set(ACCOUNT_QUERY, sample_records()).await.unwrap();
        assert!(cache.get::<TestRecord>(ACCOUNT_QUERY).await.is_none());
    }
}