Skip to main content

zeph_memory/
response_cache.rs

1// SPDX-FileCopyrightText: 2026 Andrei G <bug-ops>
2// SPDX-License-Identifier: MIT OR Apache-2.0
3
4use sqlx::SqlitePool;
5
6use crate::error::MemoryError;
7
8pub struct ResponseCache {
9    pool: SqlitePool,
10    ttl_secs: u64,
11}
12
13impl ResponseCache {
14    #[must_use]
15    pub fn new(pool: SqlitePool, ttl_secs: u64) -> Self {
16        Self { pool, ttl_secs }
17    }
18
19    /// Look up a cached response by key. Returns `None` if not found or expired.
20    ///
21    /// # Errors
22    ///
23    /// Returns an error if the database query fails.
24    pub async fn get(&self, key: &str) -> Result<Option<String>, MemoryError> {
25        let now = unix_now();
26        let row: Option<(String,)> = sqlx::query_as(
27            "SELECT response FROM response_cache WHERE cache_key = ? AND expires_at > ?",
28        )
29        .bind(key)
30        .bind(now)
31        .fetch_optional(&self.pool)
32        .await?;
33        Ok(row.map(|(r,)| r))
34    }
35
36    /// Store a response in the cache with TTL.
37    ///
38    /// # Errors
39    ///
40    /// Returns an error if the database insert fails.
41    pub async fn put(&self, key: &str, response: &str, model: &str) -> Result<(), MemoryError> {
42        let now = unix_now();
43        // Cap TTL at 1 year (31_536_000 s) to prevent i64 overflow for extreme values.
44        let expires_at = now.saturating_add(self.ttl_secs.min(31_536_000).cast_signed());
45        sqlx::query(
46            "INSERT OR REPLACE INTO response_cache (cache_key, response, model, created_at, expires_at) \
47             VALUES (?, ?, ?, ?, ?)",
48        )
49        .bind(key)
50        .bind(response)
51        .bind(model)
52        .bind(now)
53        .bind(expires_at)
54        .execute(&self.pool)
55        .await?;
56        Ok(())
57    }
58
59    /// Delete expired cache entries. Returns the number of rows deleted.
60    ///
61    /// # Errors
62    ///
63    /// Returns an error if the database delete fails.
64    pub async fn cleanup_expired(&self) -> Result<u64, MemoryError> {
65        let now = unix_now();
66        let result = sqlx::query("DELETE FROM response_cache WHERE expires_at <= ?")
67            .bind(now)
68            .execute(&self.pool)
69            .await?;
70        Ok(result.rows_affected())
71    }
72
73    /// Compute a deterministic cache key from the last user message and model name using blake3.
74    ///
75    /// The key intentionally ignores conversation history so that identical user messages
76    /// produce cache hits regardless of what preceded them. This is the desired behavior for
77    /// a short-TTL response cache, but it means context-dependent questions (e.g. "Explain
78    /// this") may return a cached response from a different context. The TTL bounds staleness.
79    #[must_use]
80    pub fn compute_key(last_user_message: &str, model: &str) -> String {
81        let mut hasher = blake3::Hasher::new();
82        let content = last_user_message.as_bytes();
83        hasher.update(&(content.len() as u64).to_le_bytes());
84        hasher.update(content);
85        let model_bytes = model.as_bytes();
86        hasher.update(&(model_bytes.len() as u64).to_le_bytes());
87        hasher.update(model_bytes);
88        hasher.finalize().to_hex().to_string()
89    }
90}
91
/// Current wall-clock time as whole seconds since the Unix epoch.
///
/// If the system clock reports a time before 1970, this yields `0` instead of failing.
fn unix_now() -> i64 {
    use std::time::{Duration, SystemTime, UNIX_EPOCH};

    let elapsed = match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(since_epoch) => since_epoch,
        Err(_) => Duration::ZERO,
    };
    elapsed.as_secs().cast_signed()
}
99
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sqlite::SqliteStore;

    /// Build a cache over a fresh in-memory store using the given TTL.
    async fn cache_with_ttl(ttl_secs: u64) -> ResponseCache {
        let store = SqliteStore::new(":memory:").await.unwrap();
        ResponseCache::new(store.pool().clone(), ttl_secs)
    }

    /// Build a cache with a one-hour TTL, long enough that nothing expires mid-test.
    async fn test_cache() -> ResponseCache {
        cache_with_ttl(3600).await
    }

    #[tokio::test]
    async fn cache_miss_returns_none() {
        let cache = test_cache().await;
        let result = cache.get("nonexistent").await.unwrap();
        assert!(result.is_none());
    }

    #[tokio::test]
    async fn cache_put_and_get_roundtrip() {
        let cache = test_cache().await;
        cache.put("key1", "response text", "gpt-4").await.unwrap();
        let result = cache.get("key1").await.unwrap();
        assert_eq!(result.as_deref(), Some("response text"));
    }

    #[tokio::test]
    async fn cache_expired_entry_returns_none() {
        // ttl = 0 stores expires_at == now, which already fails `get`'s `expires_at > now`.
        let cache = cache_with_ttl(0).await;
        cache.put("key1", "response", "model").await.unwrap();
        let result = cache.get("key1").await.unwrap();
        assert!(result.is_none());
    }

    #[tokio::test]
    async fn cleanup_expired_removes_entries() {
        let cache = cache_with_ttl(0).await;
        cache.put("key1", "response", "model").await.unwrap();
        let deleted = cache.cleanup_expired().await.unwrap();
        assert_eq!(deleted, 1);
        // `get` hides expired rows anyway, so query the table directly to prove deletion.
        let remaining: Option<(String,)> =
            sqlx::query_as("SELECT response FROM response_cache WHERE cache_key = ?")
                .bind("key1")
                .fetch_optional(&cache.pool)
                .await
                .unwrap();
        assert!(remaining.is_none());
    }

    #[tokio::test]
    async fn cleanup_does_not_remove_valid_entries() {
        let cache = test_cache().await;
        cache.put("key1", "response", "model").await.unwrap();
        let deleted = cache.cleanup_expired().await.unwrap();
        assert_eq!(deleted, 0);
        let result = cache.get("key1").await.unwrap();
        assert!(result.is_some());
    }

    #[test]
    fn compute_key_deterministic() {
        let k1 = ResponseCache::compute_key("hello", "gpt-4");
        let k2 = ResponseCache::compute_key("hello", "gpt-4");
        assert_eq!(k1, k2);
    }

    #[test]
    fn compute_key_different_for_different_content() {
        assert_ne!(
            ResponseCache::compute_key("hello", "gpt-4"),
            ResponseCache::compute_key("world", "gpt-4")
        );
    }

    #[test]
    fn compute_key_different_for_different_model() {
        assert_ne!(
            ResponseCache::compute_key("hello", "gpt-4"),
            ResponseCache::compute_key("hello", "gpt-3.5")
        );
    }

    #[test]
    fn compute_key_field_boundaries_are_unambiguous() {
        // Length prefixes keep ("ab", "c") and ("a", "bc") from hashing the same bytes.
        assert_ne!(
            ResponseCache::compute_key("ab", "c"),
            ResponseCache::compute_key("a", "bc")
        );
    }

    #[test]
    fn compute_key_empty_message() {
        let k = ResponseCache::compute_key("", "model");
        // A blake3 digest is 32 bytes, i.e. 64 hex characters.
        assert_eq!(k.len(), 64);
        assert!(k.chars().all(|c| c.is_ascii_hexdigit()));
    }

    #[tokio::test]
    async fn ttl_extreme_value_does_not_overflow() {
        // Without the one-year cap, converting this TTL to i64 would wrap to a negative
        // offset and the entry would be stored already expired.
        let cache = cache_with_ttl(u64::MAX - 1).await;
        cache.put("key1", "response", "model").await.unwrap();
        // Entry should be retrievable (far-future expiry).
        let result = cache.get("key1").await.unwrap();
        assert_eq!(result.as_deref(), Some("response"));
    }

    #[tokio::test]
    async fn ttl_longer_than_a_year_is_capped() {
        // Two years requested; the stored lifetime must be exactly one year (31_536_000 s).
        let cache = cache_with_ttl(2 * 31_536_000).await;
        cache.put("key1", "response", "model").await.unwrap();
        let (created_at, expires_at): (i64, i64) = sqlx::query_as(
            "SELECT created_at, expires_at FROM response_cache WHERE cache_key = ?",
        )
        .bind("key1")
        .fetch_one(&cache.pool)
        .await
        .unwrap();
        assert_eq!(expires_at - created_at, 31_536_000);
    }

    #[tokio::test]
    async fn insert_or_replace_updates_existing_entry() {
        let cache = test_cache().await;
        cache.put("key1", "first response", "gpt-4").await.unwrap();
        cache.put("key1", "second response", "gpt-4").await.unwrap();
        let result = cache.get("key1").await.unwrap();
        assert_eq!(result.as_deref(), Some("second response"));
    }
}