zeph_memory/
response_cache.rs1use sqlx::SqlitePool;
5
6use crate::error::MemoryError;
7
8pub struct ResponseCache {
9 pool: SqlitePool,
10 ttl_secs: u64,
11}
12
13impl ResponseCache {
14 #[must_use]
15 pub fn new(pool: SqlitePool, ttl_secs: u64) -> Self {
16 Self { pool, ttl_secs }
17 }
18
19 pub async fn get(&self, key: &str) -> Result<Option<String>, MemoryError> {
25 let now = unix_now();
26 let row: Option<(String,)> = sqlx::query_as(
27 "SELECT response FROM response_cache WHERE cache_key = ? AND expires_at > ?",
28 )
29 .bind(key)
30 .bind(now)
31 .fetch_optional(&self.pool)
32 .await?;
33 Ok(row.map(|(r,)| r))
34 }
35
36 pub async fn put(&self, key: &str, response: &str, model: &str) -> Result<(), MemoryError> {
42 let now = unix_now();
43 let expires_at = now.saturating_add(self.ttl_secs.min(31_536_000).cast_signed());
45 sqlx::query(
46 "INSERT OR REPLACE INTO response_cache (cache_key, response, model, created_at, expires_at) \
47 VALUES (?, ?, ?, ?, ?)",
48 )
49 .bind(key)
50 .bind(response)
51 .bind(model)
52 .bind(now)
53 .bind(expires_at)
54 .execute(&self.pool)
55 .await?;
56 Ok(())
57 }
58
59 pub async fn cleanup_expired(&self) -> Result<u64, MemoryError> {
65 let now = unix_now();
66 let result = sqlx::query("DELETE FROM response_cache WHERE expires_at <= ?")
67 .bind(now)
68 .execute(&self.pool)
69 .await?;
70 Ok(result.rows_affected())
71 }
72
73 #[must_use]
80 pub fn compute_key(last_user_message: &str, model: &str) -> String {
81 let mut hasher = blake3::Hasher::new();
82 let content = last_user_message.as_bytes();
83 hasher.update(&(content.len() as u64).to_le_bytes());
84 hasher.update(content);
85 let model_bytes = model.as_bytes();
86 hasher.update(&(model_bytes.len() as u64).to_le_bytes());
87 hasher.update(model_bytes);
88 hasher.finalize().to_hex().to_string()
89 }
90}
91
/// Current wall-clock time as whole seconds since the Unix epoch.
///
/// If the system clock reads earlier than 1970-01-01, this yields `0` instead of
/// failing.
fn unix_now() -> i64 {
    let since_epoch = std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH);
    since_epoch.map_or(0, |elapsed| elapsed.as_secs().cast_signed())
}
99
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sqlite::SqliteStore;

    /// One 365-day year in seconds: the TTL cap applied by `ResponseCache::put`.
    const ONE_YEAR_SECS: i64 = 31_536_000;

    /// Builds a cache over a fresh in-memory store with a one-hour TTL.
    async fn test_cache() -> ResponseCache {
        let store = SqliteStore::new(":memory:").await.unwrap();
        ResponseCache::new(store.pool().clone(), 3600)
    }

    #[tokio::test]
    async fn cache_miss_returns_none() {
        let cache = test_cache().await;
        let result = cache.get("nonexistent").await.unwrap();
        assert!(result.is_none());
    }

    #[tokio::test]
    async fn cache_put_and_get_roundtrip() {
        let cache = test_cache().await;
        cache.put("key1", "response text", "gpt-4").await.unwrap();
        let result = cache.get("key1").await.unwrap();
        assert_eq!(result.as_deref(), Some("response text"));
    }

    #[tokio::test]
    async fn cache_expired_entry_returns_none() {
        let store = SqliteStore::new(":memory:").await.unwrap();
        let cache = ResponseCache::new(store.pool().clone(), 0);
        cache.put("key1", "response", "model").await.unwrap();
        // A zero TTL makes expires_at == created_at. `get` requires expires_at to be
        // strictly greater than the current time, so this is already a miss.
        let result = cache.get("key1").await.unwrap();
        assert!(result.is_none());
    }

    #[tokio::test]
    async fn cleanup_expired_removes_entries() {
        let store = SqliteStore::new(":memory:").await.unwrap();
        let cache = ResponseCache::new(store.pool().clone(), 0);
        cache.put("key1", "response", "model").await.unwrap();
        let deleted = cache.cleanup_expired().await.unwrap();
        assert_eq!(deleted, 1);
        // A second sweep must find nothing, proving the row was actually removed.
        let deleted_again = cache.cleanup_expired().await.unwrap();
        assert_eq!(deleted_again, 0);
    }

    #[tokio::test]
    async fn cleanup_does_not_remove_valid_entries() {
        let cache = test_cache().await;
        cache.put("key1", "response", "model").await.unwrap();
        let deleted = cache.cleanup_expired().await.unwrap();
        assert_eq!(deleted, 0);
        let result = cache.get("key1").await.unwrap();
        assert!(result.is_some());
    }

    #[test]
    fn compute_key_deterministic() {
        let k1 = ResponseCache::compute_key("hello", "gpt-4");
        let k2 = ResponseCache::compute_key("hello", "gpt-4");
        assert_eq!(k1, k2);
    }

    #[test]
    fn compute_key_different_for_different_content() {
        assert_ne!(
            ResponseCache::compute_key("hello", "gpt-4"),
            ResponseCache::compute_key("world", "gpt-4")
        );
    }

    #[test]
    fn compute_key_different_for_different_model() {
        assert_ne!(
            ResponseCache::compute_key("hello", "gpt-4"),
            ResponseCache::compute_key("hello", "gpt-3.5")
        );
    }

    #[test]
    fn compute_key_empty_message() {
        let k = ResponseCache::compute_key("", "model");
        assert!(!k.is_empty());
        // A BLAKE3 digest is 32 bytes, hex-encoded as 64 characters.
        assert_eq!(k.len(), 64);
        assert!(k.chars().all(|c| c.is_ascii_hexdigit()));
    }

    #[test]
    fn compute_key_length_prefix_separates_fields() {
        // Without per-field length prefixes, each pair would hash identical
        // concatenated bytes.
        assert_ne!(
            ResponseCache::compute_key("ab", "c"),
            ResponseCache::compute_key("a", "bc")
        );
        assert_ne!(
            ResponseCache::compute_key("model", ""),
            ResponseCache::compute_key("", "model")
        );
    }

    #[tokio::test]
    async fn ttl_extreme_value_does_not_overflow() {
        let store = SqliteStore::new(":memory:").await.unwrap();
        let cache = ResponseCache::new(store.pool().clone(), u64::MAX - 1);
        // Unclamped, u64::MAX - 1 would wrap to a negative i64 TTL, so the entry
        // would be born expired.
        cache.put("key1", "response", "model").await.unwrap();
        let result = cache.get("key1").await.unwrap();
        assert_eq!(result.as_deref(), Some("response"));

        // The stored lifetime must equal the one-year cap exactly.
        let lifetime: i64 = sqlx::query_scalar(
            "SELECT expires_at - created_at FROM response_cache WHERE cache_key = ?",
        )
        .bind("key1")
        .fetch_one(&cache.pool)
        .await
        .unwrap();
        assert_eq!(lifetime, ONE_YEAR_SECS);
    }

    #[tokio::test]
    async fn insert_or_replace_updates_existing_entry() {
        let cache = test_cache().await;
        cache.put("key1", "first response", "gpt-4").await.unwrap();
        cache.put("key1", "second response", "gpt-4").await.unwrap();
        let result = cache.get("key1").await.unwrap();
        assert_eq!(result.as_deref(), Some("second response"));
    }
}