Skip to main content

zeph_memory/
response_cache.rs

1// SPDX-FileCopyrightText: 2026 Andrei G <bug-ops>
2// SPDX-License-Identifier: MIT OR Apache-2.0
3
4use sqlx::SqlitePool;
5use zeph_llm::provider::{Message, Role};
6
7use crate::error::MemoryError;
8
9fn role_to_str(role: Role) -> &'static str {
10    match role {
11        Role::System => "system",
12        Role::User => "user",
13        Role::Assistant => "assistant",
14    }
15}
16
17pub struct ResponseCache {
18    pool: SqlitePool,
19    ttl_secs: u64,
20}
21
22impl ResponseCache {
23    #[must_use]
24    pub fn new(pool: SqlitePool, ttl_secs: u64) -> Self {
25        Self { pool, ttl_secs }
26    }
27
28    /// Look up a cached response by key. Returns `None` if not found or expired.
29    ///
30    /// # Errors
31    ///
32    /// Returns an error if the database query fails.
33    pub async fn get(&self, key: &str) -> Result<Option<String>, MemoryError> {
34        let now = unix_now();
35        let row: Option<(String,)> = sqlx::query_as(
36            "SELECT response FROM response_cache WHERE cache_key = ? AND expires_at > ?",
37        )
38        .bind(key)
39        .bind(now)
40        .fetch_optional(&self.pool)
41        .await?;
42        Ok(row.map(|(r,)| r))
43    }
44
45    /// Store a response in the cache with TTL.
46    ///
47    /// # Errors
48    ///
49    /// Returns an error if the database insert fails.
50    pub async fn put(&self, key: &str, response: &str, model: &str) -> Result<(), MemoryError> {
51        let now = unix_now();
52        // Cap TTL at 1 year (31_536_000 s) to prevent i64 overflow for extreme values.
53        let expires_at = now.saturating_add(self.ttl_secs.min(31_536_000).cast_signed());
54        sqlx::query(
55            "INSERT OR REPLACE INTO response_cache (cache_key, response, model, created_at, expires_at) \
56             VALUES (?, ?, ?, ?, ?)",
57        )
58        .bind(key)
59        .bind(response)
60        .bind(model)
61        .bind(now)
62        .bind(expires_at)
63        .execute(&self.pool)
64        .await?;
65        Ok(())
66    }
67
68    /// Delete expired cache entries. Returns the number of rows deleted.
69    ///
70    /// # Errors
71    ///
72    /// Returns an error if the database delete fails.
73    pub async fn cleanup_expired(&self) -> Result<u64, MemoryError> {
74        let now = unix_now();
75        let result = sqlx::query("DELETE FROM response_cache WHERE expires_at <= ?")
76            .bind(now)
77            .execute(&self.pool)
78            .await?;
79        Ok(result.rows_affected())
80    }
81
82    /// Compute a deterministic cache key from messages and model name using blake3.
83    #[must_use]
84    pub fn compute_key(messages: &[Message], model: &str) -> String {
85        let mut hasher = blake3::Hasher::new();
86        for msg in messages {
87            let role = role_to_str(msg.role).as_bytes();
88            hasher.update(&(role.len() as u64).to_le_bytes());
89            hasher.update(role);
90            let content = msg.content.as_bytes();
91            hasher.update(&(content.len() as u64).to_le_bytes());
92            hasher.update(content);
93        }
94        let model_bytes = model.as_bytes();
95        hasher.update(&(model_bytes.len() as u64).to_le_bytes());
96        hasher.update(model_bytes);
97        hasher.finalize().to_hex().to_string()
98    }
99}
100
/// Current wall-clock time in whole seconds since the Unix epoch.
///
/// If the system clock reads earlier than the epoch, this yields `0` rather than
/// failing. The `u64` seconds are bit-reinterpreted as `i64`.
fn unix_now() -> i64 {
    use std::time::{SystemTime, UNIX_EPOCH};

    let since_epoch = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .unwrap_or_default();
    since_epoch.as_secs().cast_signed()
}
108
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sqlite::SqliteStore;
    use zeph_llm::provider::MessageMetadata;

    async fn test_cache() -> ResponseCache {
        let store = SqliteStore::new(":memory:").await.unwrap();
        ResponseCache::new(store.pool().clone(), 3600)
    }

    /// Builds a plain-text message with no extra parts and default metadata.
    fn msg(role: Role, content: &str) -> Message {
        Message {
            role,
            content: content.into(),
            parts: vec![],
            metadata: MessageMetadata::default(),
        }
    }

    #[tokio::test]
    async fn cache_miss_returns_none() {
        let cache = test_cache().await;
        let result = cache.get("nonexistent").await.unwrap();
        assert!(result.is_none());
    }

    #[tokio::test]
    async fn cache_put_and_get_roundtrip() {
        let cache = test_cache().await;
        cache.put("key1", "response text", "gpt-4").await.unwrap();
        let result = cache.get("key1").await.unwrap();
        assert_eq!(result.as_deref(), Some("response text"));
    }

    #[tokio::test]
    async fn cache_expired_entry_returns_none() {
        let store = SqliteStore::new(":memory:").await.unwrap();
        let cache = ResponseCache::new(store.pool().clone(), 0);
        // ttl=0 stores expires_at == now; `get` requires expires_at > now.
        cache.put("key1", "response", "model").await.unwrap();
        let result = cache.get("key1").await.unwrap();
        assert!(result.is_none());
    }

    #[tokio::test]
    async fn cleanup_expired_removes_entries() {
        let store = SqliteStore::new(":memory:").await.unwrap();
        let cache = ResponseCache::new(store.pool().clone(), 0);
        cache.put("key1", "response", "model").await.unwrap();
        let deleted = cache.cleanup_expired().await.unwrap();
        assert_eq!(deleted, 1);
        // A second sweep has nothing left to delete.
        assert_eq!(cache.cleanup_expired().await.unwrap(), 0);
    }

    #[tokio::test]
    async fn cleanup_does_not_remove_valid_entries() {
        let cache = test_cache().await;
        cache.put("key1", "response", "model").await.unwrap();
        let deleted = cache.cleanup_expired().await.unwrap();
        assert_eq!(deleted, 0);
        let result = cache.get("key1").await.unwrap();
        assert!(result.is_some());
    }

    #[test]
    fn compute_key_deterministic() {
        let msgs = vec![msg(Role::User, "hello")];
        let k1 = ResponseCache::compute_key(&msgs, "gpt-4");
        let k2 = ResponseCache::compute_key(&msgs, "gpt-4");
        assert_eq!(k1, k2);
    }

    #[test]
    fn compute_key_different_for_different_content() {
        let msgs1 = vec![msg(Role::User, "hello")];
        let msgs2 = vec![msg(Role::User, "world")];
        assert_ne!(
            ResponseCache::compute_key(&msgs1, "gpt-4"),
            ResponseCache::compute_key(&msgs2, "gpt-4")
        );
    }

    #[test]
    fn compute_key_different_for_different_model() {
        let msgs = vec![msg(Role::User, "hello")];
        assert_ne!(
            ResponseCache::compute_key(&msgs, "gpt-4"),
            ResponseCache::compute_key(&msgs, "gpt-3.5")
        );
    }

    #[test]
    fn compute_key_different_for_different_role() {
        let user = vec![msg(Role::User, "hello")];
        let assistant = vec![msg(Role::Assistant, "hello")];
        assert_ne!(
            ResponseCache::compute_key(&user, "gpt-4"),
            ResponseCache::compute_key(&assistant, "gpt-4")
        );
    }

    #[test]
    fn compute_key_sensitive_to_message_order() {
        let forward = vec![msg(Role::User, "first"), msg(Role::Assistant, "second")];
        let reversed = vec![msg(Role::Assistant, "second"), msg(Role::User, "first")];
        assert_ne!(
            ResponseCache::compute_key(&forward, "model"),
            ResponseCache::compute_key(&reversed, "model")
        );
    }

    #[test]
    fn compute_key_empty_messages() {
        let k = ResponseCache::compute_key(&[], "model");
        // blake3 produces a 32-byte digest, i.e. 64 hex characters.
        assert_eq!(k.len(), 64);
        assert!(k.chars().all(|c| c.is_ascii_hexdigit()));
    }

    #[test]
    fn compute_key_no_length_prefix_ambiguity() {
        // Without length-prefix, "ab"+"c" and "a"+"bc" would hash identically.
        // With proper length-prefixing they must differ.
        let msgs1 = vec![msg(Role::User, "ab"), msg(Role::User, "c")];
        let msgs2 = vec![msg(Role::User, "a"), msg(Role::User, "bc")];
        assert_ne!(
            ResponseCache::compute_key(&msgs1, "model"),
            ResponseCache::compute_key(&msgs2, "model")
        );
    }

    #[tokio::test]
    async fn ttl_extreme_value_does_not_overflow() {
        let store = SqliteStore::new(":memory:").await.unwrap();
        // Use u64::MAX - 1 as TTL; without capping this would overflow i64.
        let cache = ResponseCache::new(store.pool().clone(), u64::MAX - 1);
        // Should not panic or produce a negative expires_at.
        cache.put("key1", "response", "model").await.unwrap();
        // Entry should be retrievable (far-future expiry).
        let result = cache.get("key1").await.unwrap();
        assert_eq!(result.as_deref(), Some("response"));
    }

    #[tokio::test]
    async fn insert_or_replace_updates_existing_entry() {
        let cache = test_cache().await;
        cache.put("key1", "first response", "gpt-4").await.unwrap();
        cache.put("key1", "second response", "gpt-4").await.unwrap();
        let result = cache.get("key1").await.unwrap();
        assert_eq!(result.as_deref(), Some("second response"));
    }
}