1use sqlx::SqlitePool;
5use zeph_llm::provider::{Message, Role};
6
7use crate::error::MemoryError;
8
9fn role_to_str(role: Role) -> &'static str {
10 match role {
11 Role::System => "system",
12 Role::User => "user",
13 Role::Assistant => "assistant",
14 }
15}
16
17pub struct ResponseCache {
18 pool: SqlitePool,
19 ttl_secs: u64,
20}
21
22impl ResponseCache {
23 #[must_use]
24 pub fn new(pool: SqlitePool, ttl_secs: u64) -> Self {
25 Self { pool, ttl_secs }
26 }
27
28 pub async fn get(&self, key: &str) -> Result<Option<String>, MemoryError> {
34 let now = unix_now();
35 let row: Option<(String,)> = sqlx::query_as(
36 "SELECT response FROM response_cache WHERE cache_key = ? AND expires_at > ?",
37 )
38 .bind(key)
39 .bind(now)
40 .fetch_optional(&self.pool)
41 .await?;
42 Ok(row.map(|(r,)| r))
43 }
44
45 pub async fn put(&self, key: &str, response: &str, model: &str) -> Result<(), MemoryError> {
51 let now = unix_now();
52 let expires_at = now.saturating_add(self.ttl_secs.min(31_536_000).cast_signed());
54 sqlx::query(
55 "INSERT OR REPLACE INTO response_cache (cache_key, response, model, created_at, expires_at) \
56 VALUES (?, ?, ?, ?, ?)",
57 )
58 .bind(key)
59 .bind(response)
60 .bind(model)
61 .bind(now)
62 .bind(expires_at)
63 .execute(&self.pool)
64 .await?;
65 Ok(())
66 }
67
68 pub async fn cleanup_expired(&self) -> Result<u64, MemoryError> {
74 let now = unix_now();
75 let result = sqlx::query("DELETE FROM response_cache WHERE expires_at <= ?")
76 .bind(now)
77 .execute(&self.pool)
78 .await?;
79 Ok(result.rows_affected())
80 }
81
82 #[must_use]
84 pub fn compute_key(messages: &[Message], model: &str) -> String {
85 let mut hasher = blake3::Hasher::new();
86 for msg in messages {
87 let role = role_to_str(msg.role).as_bytes();
88 hasher.update(&(role.len() as u64).to_le_bytes());
89 hasher.update(role);
90 let content = msg.content.as_bytes();
91 hasher.update(&(content.len() as u64).to_le_bytes());
92 hasher.update(content);
93 }
94 let model_bytes = model.as_bytes();
95 hasher.update(&(model_bytes.len() as u64).to_le_bytes());
96 hasher.update(model_bytes);
97 hasher.finalize().to_hex().to_string()
98 }
99}
100
/// Current wall-clock time in whole seconds since the Unix epoch.
///
/// If the system clock reports a time before the epoch, this returns `0`
/// rather than failing.
fn unix_now() -> i64 {
    use std::time::{SystemTime, UNIX_EPOCH};

    let since_epoch = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .unwrap_or_default();
    since_epoch.as_secs().cast_signed()
}
108
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sqlite::SqliteStore;
    use zeph_llm::provider::MessageMetadata;

    /// Fresh in-memory cache with a one-hour TTL.
    async fn test_cache() -> ResponseCache {
        cache_with_ttl(3600).await
    }

    /// Fresh in-memory cache with the given TTL in seconds.
    async fn cache_with_ttl(ttl_secs: u64) -> ResponseCache {
        let store = SqliteStore::new(":memory:").await.unwrap();
        ResponseCache::new(store.pool().clone(), ttl_secs)
    }

    /// Builds a message with the given role and content, empty parts and default metadata.
    fn msg(role: Role, content: &str) -> Message {
        Message {
            role,
            content: content.into(),
            parts: vec![],
            metadata: MessageMetadata::default(),
        }
    }

    #[tokio::test]
    async fn cache_miss_returns_none() {
        let cache = test_cache().await;
        let result = cache.get("nonexistent").await.unwrap();
        assert!(result.is_none());
    }

    #[tokio::test]
    async fn cache_put_and_get_roundtrip() {
        let cache = test_cache().await;
        cache.put("key1", "response text", "gpt-4").await.unwrap();
        let result = cache.get("key1").await.unwrap();
        assert_eq!(result.as_deref(), Some("response text"));
    }

    #[tokio::test]
    async fn cache_expired_entry_returns_none() {
        let cache = cache_with_ttl(0).await;
        cache.put("key1", "response", "model").await.unwrap();
        // With a zero TTL, expires_at == created_at, which `get` treats as expired.
        let result = cache.get("key1").await.unwrap();
        assert!(result.is_none());
    }

    #[tokio::test]
    async fn cleanup_expired_removes_entries() {
        let cache = cache_with_ttl(0).await;
        cache.put("key1", "response", "model").await.unwrap();
        let deleted = cache.cleanup_expired().await.unwrap();
        assert_eq!(deleted, 1);
        assert!(cache.get("key1").await.unwrap().is_none());
    }

    #[tokio::test]
    async fn cleanup_does_not_remove_valid_entries() {
        let cache = test_cache().await;
        cache.put("key1", "response", "model").await.unwrap();
        let deleted = cache.cleanup_expired().await.unwrap();
        assert_eq!(deleted, 0);
        let result = cache.get("key1").await.unwrap();
        assert!(result.is_some());
    }

    #[test]
    fn compute_key_deterministic() {
        let msgs = [msg(Role::User, "hello")];
        let k1 = ResponseCache::compute_key(&msgs, "gpt-4");
        let k2 = ResponseCache::compute_key(&msgs, "gpt-4");
        assert_eq!(k1, k2);
    }

    #[test]
    fn compute_key_different_for_different_content() {
        assert_ne!(
            ResponseCache::compute_key(&[msg(Role::User, "hello")], "gpt-4"),
            ResponseCache::compute_key(&[msg(Role::User, "world")], "gpt-4")
        );
    }

    #[test]
    fn compute_key_different_for_different_model() {
        let msgs = [msg(Role::User, "hello")];
        assert_ne!(
            ResponseCache::compute_key(&msgs, "gpt-4"),
            ResponseCache::compute_key(&msgs, "gpt-3.5")
        );
    }

    #[test]
    fn compute_key_different_for_different_role() {
        assert_ne!(
            ResponseCache::compute_key(&[msg(Role::User, "hello")], "gpt-4"),
            ResponseCache::compute_key(&[msg(Role::Assistant, "hello")], "gpt-4")
        );
    }

    #[test]
    fn compute_key_empty_messages() {
        let k = ResponseCache::compute_key(&[], "model");
        // BLAKE3 produces a 32-byte digest, i.e. 64 hex characters.
        assert_eq!(k.len(), 64);
        assert!(k.chars().all(|c| c.is_ascii_hexdigit()));
    }

    #[test]
    fn compute_key_no_length_prefix_ambiguity() {
        // Both conversations concatenate to "abc"; only the per-field length
        // prefixes keep their keys apart.
        let msgs1 = [msg(Role::User, "ab"), msg(Role::User, "c")];
        let msgs2 = [msg(Role::User, "a"), msg(Role::User, "bc")];
        assert_ne!(
            ResponseCache::compute_key(&msgs1, "model"),
            ResponseCache::compute_key(&msgs2, "model")
        );
    }

    #[tokio::test]
    async fn ttl_extreme_value_does_not_overflow() {
        let cache = cache_with_ttl(u64::MAX - 1).await;
        cache.put("key1", "response", "model").await.unwrap();
        let result = cache.get("key1").await.unwrap();
        assert_eq!(result.as_deref(), Some("response"));
    }

    #[tokio::test]
    async fn ttl_is_clamped_to_one_year() {
        let cache = cache_with_ttl(u64::MAX).await;
        cache.put("key1", "response", "model").await.unwrap();
        let (created_at, expires_at): (i64, i64) = sqlx::query_as(
            "SELECT created_at, expires_at FROM response_cache WHERE cache_key = ?",
        )
        .bind("key1")
        .fetch_one(&cache.pool)
        .await
        .unwrap();
        assert_eq!(expires_at - created_at, 31_536_000);
    }

    #[tokio::test]
    async fn insert_or_replace_updates_existing_entry() {
        let cache = test_cache().await;
        cache.put("key1", "first response", "gpt-4").await.unwrap();
        cache.put("key1", "second response", "gpt-4").await.unwrap();
        let result = cache.get("key1").await.unwrap();
        assert_eq!(result.as_deref(), Some("second response"));
    }
}