1use sqlx::SqlitePool;
2use zeph_llm::provider::{Message, Role};
3
4use crate::error::MemoryError;
5
6fn role_to_str(role: Role) -> &'static str {
7 match role {
8 Role::System => "system",
9 Role::User => "user",
10 Role::Assistant => "assistant",
11 }
12}
13
14pub struct ResponseCache {
15 pool: SqlitePool,
16 ttl_secs: u64,
17}
18
19impl ResponseCache {
20 #[must_use]
21 pub fn new(pool: SqlitePool, ttl_secs: u64) -> Self {
22 Self { pool, ttl_secs }
23 }
24
25 pub async fn get(&self, key: &str) -> Result<Option<String>, MemoryError> {
31 let now = unix_now();
32 let row: Option<(String,)> = sqlx::query_as(
33 "SELECT response FROM response_cache WHERE cache_key = ? AND expires_at > ?",
34 )
35 .bind(key)
36 .bind(now)
37 .fetch_optional(&self.pool)
38 .await?;
39 Ok(row.map(|(r,)| r))
40 }
41
42 pub async fn put(&self, key: &str, response: &str, model: &str) -> Result<(), MemoryError> {
48 let now = unix_now();
49 let expires_at = now.saturating_add(self.ttl_secs.min(31_536_000).cast_signed());
51 sqlx::query(
52 "INSERT OR REPLACE INTO response_cache (cache_key, response, model, created_at, expires_at) \
53 VALUES (?, ?, ?, ?, ?)",
54 )
55 .bind(key)
56 .bind(response)
57 .bind(model)
58 .bind(now)
59 .bind(expires_at)
60 .execute(&self.pool)
61 .await?;
62 Ok(())
63 }
64
65 pub async fn cleanup_expired(&self) -> Result<u64, MemoryError> {
71 let now = unix_now();
72 let result = sqlx::query("DELETE FROM response_cache WHERE expires_at <= ?")
73 .bind(now)
74 .execute(&self.pool)
75 .await?;
76 Ok(result.rows_affected())
77 }
78
79 #[must_use]
81 pub fn compute_key(messages: &[Message], model: &str) -> String {
82 let mut hasher = blake3::Hasher::new();
83 for msg in messages {
84 let role = role_to_str(msg.role).as_bytes();
85 hasher.update(&(role.len() as u64).to_le_bytes());
86 hasher.update(role);
87 let content = msg.content.as_bytes();
88 hasher.update(&(content.len() as u64).to_le_bytes());
89 hasher.update(content);
90 }
91 let model_bytes = model.as_bytes();
92 hasher.update(&(model_bytes.len() as u64).to_le_bytes());
93 hasher.update(model_bytes);
94 hasher.finalize().to_hex().to_string()
95 }
96}
97
/// Current wall-clock time as whole seconds since the Unix epoch.
///
/// If the system clock reports a time before the epoch, this falls back to `0`
/// instead of failing.
fn unix_now() -> i64 {
    let since_epoch = std::time::UNIX_EPOCH.elapsed().unwrap_or_default();
    since_epoch.as_secs().cast_signed()
}
105
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sqlite::SqliteStore;

    /// Fresh in-memory cache with a one-hour TTL.
    async fn test_cache() -> ResponseCache {
        let store = SqliteStore::new(":memory:").await.unwrap();
        ResponseCache::new(store.pool().clone(), 3600)
    }

    /// Builds a message with the given role and content and no extra parts.
    fn msg(role: Role, content: &'static str) -> Message {
        Message {
            role,
            content: content.into(),
            parts: vec![],
        }
    }

    #[tokio::test]
    async fn cache_miss_returns_none() {
        let cache = test_cache().await;
        let result = cache.get("nonexistent").await.unwrap();
        assert!(result.is_none());
    }

    #[tokio::test]
    async fn cache_put_and_get_roundtrip() {
        let cache = test_cache().await;
        cache.put("key1", "response text", "gpt-4").await.unwrap();
        let result = cache.get("key1").await.unwrap();
        assert_eq!(result.as_deref(), Some("response text"));
    }

    #[tokio::test]
    async fn cache_expired_entry_returns_none() {
        let store = SqliteStore::new(":memory:").await.unwrap();
        let cache = ResponseCache::new(store.pool().clone(), 0);
        cache.put("key1", "response", "model").await.unwrap();
        // A zero TTL sets expires_at == now, and `get` requires expires_at > now.
        let result = cache.get("key1").await.unwrap();
        assert!(result.is_none());
    }

    #[tokio::test]
    async fn cleanup_expired_removes_entries() {
        let store = SqliteStore::new(":memory:").await.unwrap();
        let cache = ResponseCache::new(store.pool().clone(), 0);
        cache.put("key1", "response", "model").await.unwrap();
        let deleted = cache.cleanup_expired().await.unwrap();
        assert!(deleted > 0);
        // The rows must really be gone: a second pass has nothing left to delete.
        assert_eq!(cache.cleanup_expired().await.unwrap(), 0);
    }

    #[tokio::test]
    async fn cleanup_only_removes_expired_entries() {
        let store = SqliteStore::new(":memory:").await.unwrap();
        let expired = ResponseCache::new(store.pool().clone(), 0);
        let live = ResponseCache::new(store.pool().clone(), 3600);
        expired.put("old", "stale", "model").await.unwrap();
        live.put("new", "fresh", "model").await.unwrap();
        let deleted = live.cleanup_expired().await.unwrap();
        assert_eq!(deleted, 1);
        assert_eq!(live.get("new").await.unwrap().as_deref(), Some("fresh"));
    }

    #[tokio::test]
    async fn cleanup_does_not_remove_valid_entries() {
        let cache = test_cache().await;
        cache.put("key1", "response", "model").await.unwrap();
        let deleted = cache.cleanup_expired().await.unwrap();
        assert_eq!(deleted, 0);
        let result = cache.get("key1").await.unwrap();
        assert!(result.is_some());
    }

    #[test]
    fn compute_key_deterministic() {
        let msgs = vec![msg(Role::User, "hello")];
        let k1 = ResponseCache::compute_key(&msgs, "gpt-4");
        let k2 = ResponseCache::compute_key(&msgs, "gpt-4");
        assert_eq!(k1, k2);
    }

    #[test]
    fn compute_key_different_for_different_content() {
        let msgs1 = vec![msg(Role::User, "hello")];
        let msgs2 = vec![msg(Role::User, "world")];
        assert_ne!(
            ResponseCache::compute_key(&msgs1, "gpt-4"),
            ResponseCache::compute_key(&msgs2, "gpt-4")
        );
    }

    #[test]
    fn compute_key_different_for_different_role() {
        let msgs1 = vec![msg(Role::User, "hello")];
        let msgs2 = vec![msg(Role::System, "hello")];
        assert_ne!(
            ResponseCache::compute_key(&msgs1, "gpt-4"),
            ResponseCache::compute_key(&msgs2, "gpt-4")
        );
    }

    #[test]
    fn compute_key_depends_on_message_order() {
        let msgs1 = vec![msg(Role::User, "a"), msg(Role::Assistant, "b")];
        let msgs2 = vec![msg(Role::Assistant, "b"), msg(Role::User, "a")];
        assert_ne!(
            ResponseCache::compute_key(&msgs1, "model"),
            ResponseCache::compute_key(&msgs2, "model")
        );
    }

    #[test]
    fn compute_key_different_for_different_model() {
        let msgs = vec![msg(Role::User, "hello")];
        assert_ne!(
            ResponseCache::compute_key(&msgs, "gpt-4"),
            ResponseCache::compute_key(&msgs, "gpt-3.5")
        );
    }

    #[test]
    fn compute_key_empty_messages() {
        let k = ResponseCache::compute_key(&[], "model");
        assert!(!k.is_empty());
    }

    #[test]
    fn compute_key_no_length_prefix_ambiguity() {
        let msgs1 = vec![msg(Role::User, "ab"), msg(Role::User, "c")];
        let msgs2 = vec![msg(Role::User, "a"), msg(Role::User, "bc")];
        assert_ne!(
            ResponseCache::compute_key(&msgs1, "model"),
            ResponseCache::compute_key(&msgs2, "model")
        );
    }

    #[tokio::test]
    async fn ttl_extreme_value_does_not_overflow() {
        let store = SqliteStore::new(":memory:").await.unwrap();
        let cache = ResponseCache::new(store.pool().clone(), u64::MAX - 1);
        cache.put("key1", "response", "model").await.unwrap();
        // The TTL is clamped, so the entry must still be live rather than wrapped negative.
        let result = cache.get("key1").await.unwrap();
        assert_eq!(result.as_deref(), Some("response"));
    }

    #[tokio::test]
    async fn insert_or_replace_updates_existing_entry() {
        let cache = test_cache().await;
        cache.put("key1", "first response", "gpt-4").await.unwrap();
        cache.put("key1", "second response", "gpt-4").await.unwrap();
        let result = cache.get("key1").await.unwrap();
        assert_eq!(result.as_deref(), Some("second response"));
    }
}