// zeph_memory/response_cache.rs

use sqlx::SqlitePool;
use zeph_llm::provider::{Message, Role};

use crate::error::MemoryError;

6fn role_to_str(role: Role) -> &'static str {
7    match role {
8        Role::System => "system",
9        Role::User => "user",
10        Role::Assistant => "assistant",
11    }
12}
13
14pub struct ResponseCache {
15    pool: SqlitePool,
16    ttl_secs: u64,
17}
18
19impl ResponseCache {
20    #[must_use]
21    pub fn new(pool: SqlitePool, ttl_secs: u64) -> Self {
22        Self { pool, ttl_secs }
23    }
24
25    /// Look up a cached response by key. Returns `None` if not found or expired.
26    ///
27    /// # Errors
28    ///
29    /// Returns an error if the database query fails.
30    pub async fn get(&self, key: &str) -> Result<Option<String>, MemoryError> {
31        let now = unix_now();
32        let row: Option<(String,)> = sqlx::query_as(
33            "SELECT response FROM response_cache WHERE cache_key = ? AND expires_at > ?",
34        )
35        .bind(key)
36        .bind(now)
37        .fetch_optional(&self.pool)
38        .await?;
39        Ok(row.map(|(r,)| r))
40    }
41
42    /// Store a response in the cache with TTL.
43    ///
44    /// # Errors
45    ///
46    /// Returns an error if the database insert fails.
47    pub async fn put(&self, key: &str, response: &str, model: &str) -> Result<(), MemoryError> {
48        let now = unix_now();
49        // Cap TTL at 1 year (31_536_000 s) to prevent i64 overflow for extreme values.
50        let expires_at = now.saturating_add(self.ttl_secs.min(31_536_000).cast_signed());
51        sqlx::query(
52            "INSERT OR REPLACE INTO response_cache (cache_key, response, model, created_at, expires_at) \
53             VALUES (?, ?, ?, ?, ?)",
54        )
55        .bind(key)
56        .bind(response)
57        .bind(model)
58        .bind(now)
59        .bind(expires_at)
60        .execute(&self.pool)
61        .await?;
62        Ok(())
63    }
64
65    /// Delete expired cache entries. Returns the number of rows deleted.
66    ///
67    /// # Errors
68    ///
69    /// Returns an error if the database delete fails.
70    pub async fn cleanup_expired(&self) -> Result<u64, MemoryError> {
71        let now = unix_now();
72        let result = sqlx::query("DELETE FROM response_cache WHERE expires_at <= ?")
73            .bind(now)
74            .execute(&self.pool)
75            .await?;
76        Ok(result.rows_affected())
77    }
78
79    /// Compute a deterministic cache key from messages and model name using blake3.
80    #[must_use]
81    pub fn compute_key(messages: &[Message], model: &str) -> String {
82        let mut hasher = blake3::Hasher::new();
83        for msg in messages {
84            let role = role_to_str(msg.role).as_bytes();
85            hasher.update(&(role.len() as u64).to_le_bytes());
86            hasher.update(role);
87            let content = msg.content.as_bytes();
88            hasher.update(&(content.len() as u64).to_le_bytes());
89            hasher.update(content);
90        }
91        let model_bytes = model.as_bytes();
92        hasher.update(&(model_bytes.len() as u64).to_le_bytes());
93        hasher.update(model_bytes);
94        hasher.finalize().to_hex().to_string()
95    }
96}
97
/// Returns the current wall-clock time as whole seconds since the Unix epoch.
///
/// A system clock set before the epoch yields `0` instead of an error, and a
/// second count that does not fit in `i64` saturates to `i64::MAX` rather than
/// wrapping to a negative timestamp.
fn unix_now() -> i64 {
    let secs = std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        // Best effort: a pre-epoch clock is treated as the epoch itself.
        .unwrap_or_default()
        .as_secs();
    i64::try_from(secs).unwrap_or(i64::MAX)
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::sqlite::SqliteStore;

    /// Builds a cache over a fresh in-memory SQLite store with the given TTL.
    async fn cache_with_ttl(ttl_secs: u64) -> ResponseCache {
        let store = SqliteStore::new(":memory:").await.unwrap();
        ResponseCache::new(store.pool().clone(), ttl_secs)
    }

    /// Cache with a one-hour TTL, long enough that nothing expires mid-test.
    async fn test_cache() -> ResponseCache {
        cache_with_ttl(3600).await
    }

    /// User-role message carrying `content` and no parts.
    fn user_msg(content: &'static str) -> Message {
        Message {
            role: Role::User,
            content: content.into(),
            parts: vec![],
        }
    }

    #[tokio::test]
    async fn cache_miss_returns_none() {
        let cache = test_cache().await;
        let result = cache.get("nonexistent").await.unwrap();
        assert!(result.is_none());
    }

    #[tokio::test]
    async fn cache_put_and_get_roundtrip() {
        let cache = test_cache().await;
        cache.put("key1", "response text", "gpt-4").await.unwrap();
        let result = cache.get("key1").await.unwrap();
        assert_eq!(result.as_deref(), Some("response text"));
    }

    #[tokio::test]
    async fn cache_expired_entry_returns_none() {
        // ttl=0 means expires_at == now, which fails the strict `>` check in `get`.
        let cache = cache_with_ttl(0).await;
        cache.put("key1", "response", "model").await.unwrap();
        let result = cache.get("key1").await.unwrap();
        assert!(result.is_none());
    }

    #[tokio::test]
    async fn cleanup_expired_removes_entries() {
        let cache = cache_with_ttl(0).await;
        cache.put("key1", "response", "model").await.unwrap();
        let deleted = cache.cleanup_expired().await.unwrap();
        // Exactly one row was stored, and with ttl=0 it is already expired.
        assert_eq!(deleted, 1);
    }

    #[tokio::test]
    async fn cleanup_does_not_remove_valid_entries() {
        let cache = test_cache().await;
        cache.put("key1", "response", "model").await.unwrap();
        let deleted = cache.cleanup_expired().await.unwrap();
        assert_eq!(deleted, 0);
        let result = cache.get("key1").await.unwrap();
        assert!(result.is_some());
    }

    #[test]
    fn compute_key_deterministic() {
        let msgs = vec![user_msg("hello")];
        let k1 = ResponseCache::compute_key(&msgs, "gpt-4");
        let k2 = ResponseCache::compute_key(&msgs, "gpt-4");
        assert_eq!(k1, k2);
    }

    #[test]
    fn compute_key_different_for_different_content() {
        let msgs1 = vec![user_msg("hello")];
        let msgs2 = vec![user_msg("world")];
        assert_ne!(
            ResponseCache::compute_key(&msgs1, "gpt-4"),
            ResponseCache::compute_key(&msgs2, "gpt-4")
        );
    }

    #[test]
    fn compute_key_different_for_different_model() {
        let msgs = vec![user_msg("hello")];
        assert_ne!(
            ResponseCache::compute_key(&msgs, "gpt-4"),
            ResponseCache::compute_key(&msgs, "gpt-3.5")
        );
    }

    #[test]
    fn compute_key_empty_messages() {
        let k = ResponseCache::compute_key(&[], "model");
        assert!(!k.is_empty());
    }

    #[test]
    fn compute_key_no_length_prefix_ambiguity() {
        // Without length-prefix, "ab"+"c" and "a"+"bc" would hash identically.
        // With proper length-prefixing they must differ.
        let msgs1 = vec![user_msg("ab"), user_msg("c")];
        let msgs2 = vec![user_msg("a"), user_msg("bc")];
        assert_ne!(
            ResponseCache::compute_key(&msgs1, "model"),
            ResponseCache::compute_key(&msgs2, "model")
        );
    }

    #[tokio::test]
    async fn ttl_extreme_value_does_not_overflow() {
        // Use u64::MAX - 1 as TTL; without capping this would overflow i64.
        let cache = cache_with_ttl(u64::MAX - 1).await;
        // Should not panic or produce a negative expires_at.
        cache.put("key1", "response", "model").await.unwrap();
        // Entry should be retrievable (far-future expiry).
        let result = cache.get("key1").await.unwrap();
        assert_eq!(result.as_deref(), Some("response"));
    }

    #[tokio::test]
    async fn insert_or_replace_updates_existing_entry() {
        let cache = test_cache().await;
        cache.put("key1", "first response", "gpt-4").await.unwrap();
        cache.put("key1", "second response", "gpt-4").await.unwrap();
        let result = cache.get("key1").await.unwrap();
        assert_eq!(result.as_deref(), Some("second response"));
    }
}