// zeph_memory/sqlite/compression_guidelines.rs
// SPDX-FileCopyrightText: 2026 Andrei G <bug-ops>
// SPDX-License-Identifier: MIT OR Apache-2.0

//! SQLite-backed store for ACON compression guidelines and failure pairs.

use crate::error::MemoryError;
use crate::sqlite::SqliteStore;
use crate::types::ConversationId;
9
/// A recorded compression failure pair: the compressed context and the response
/// that indicated context was lost.
///
/// Rows originate from [`SqliteStore::log_compression_failure`] and are consumed
/// (and later pruned) by the guidelines updater.
#[derive(Debug, Clone)]
pub struct CompressionFailurePair {
    // Database row id (SQLite INTEGER PRIMARY KEY).
    pub id: i64,
    // Conversation in which the failure was observed.
    pub conversation_id: ConversationId,
    // The compressed context that was in effect (truncated at write time).
    pub compressed_context: String,
    // Why the compression was judged a failure (truncated at write time).
    pub failure_reason: String,
    // Row creation timestamp as stored by SQLite (string form; format set by
    // the schema default — NOTE(review): presumably UTC, confirm in migration).
    pub created_at: String,
}
20
/// Maximum **bytes** stored per `compressed_context` or `failure_reason` field.
///
/// This is a byte budget, not a character count: truncation backs up to the
/// nearest UTF-8 character boundary at or below this many bytes, so the stored
/// value may contain fewer than 4096 characters (multi-byte text) and is never
/// split mid-character.
const MAX_FIELD_CHARS: usize = 4096;

/// Truncate `s` to at most [`MAX_FIELD_CHARS`] bytes without splitting a
/// multi-byte UTF-8 character. Returns `s` unchanged when it already fits.
fn truncate_field(s: &str) -> &str {
    // Fast path: nothing to do. (The previous version counted idx down from
    // MAX_FIELD_CHARS to s.len() even for short inputs, because
    // `is_char_boundary(idx)` is false whenever idx > len.)
    if s.len() <= MAX_FIELD_CHARS {
        return s;
    }
    let mut idx = MAX_FIELD_CHARS;
    // Back up to a char boundary; terminates because index 0 is always one.
    while !s.is_char_boundary(idx) {
        idx -= 1;
    }
    &s[..idx]
}
31
32impl SqliteStore {
33    /// Load the latest active compression guidelines (global scope).
34    ///
35    /// Returns `(version, guidelines_text)`. Returns `(0, "")` if no guidelines exist yet.
36    ///
37    /// # Errors
38    ///
39    /// Returns an error if the database query fails.
40    pub async fn load_compression_guidelines(&self) -> Result<(i64, String), MemoryError> {
41        let row = sqlx::query_as::<_, (i64, String)>(
42            "SELECT version, guidelines FROM compression_guidelines ORDER BY version DESC LIMIT 1",
43        )
44        .fetch_optional(&self.pool)
45        .await?;
46
47        Ok(row.unwrap_or((0, String::new())))
48    }
49
50    /// Save a new version of the compression guidelines (global scope).
51    ///
52    /// Inserts a new row; older versions are retained for audit.
53    /// Returns the new version number.
54    ///
55    /// # Errors
56    ///
57    /// Returns an error if the database insert fails.
58    pub async fn save_compression_guidelines(
59        &self,
60        guidelines: &str,
61        token_count: i64,
62    ) -> Result<i64, MemoryError> {
63        // The INSERT...SELECT computes MAX(version)+1 and inserts it in a single
64        // statement. SQLite's single-writer WAL guarantee makes this atomic —
65        // no concurrent writer can observe the same MAX and produce a duplicate version.
66        let new_version: i64 = sqlx::query_scalar(
67            "INSERT INTO compression_guidelines (version, guidelines, token_count) \
68             SELECT COALESCE(MAX(version), 0) + 1, ?, ? \
69             FROM compression_guidelines \
70             RETURNING version",
71        )
72        .bind(guidelines)
73        .bind(token_count)
74        .fetch_one(&self.pool)
75        .await?;
76        Ok(new_version)
77    }
78
79    /// Log a compression failure pair.
80    ///
81    /// Both `compressed_context` and `failure_reason` are truncated to 4096 chars.
82    /// Returns the inserted row id.
83    ///
84    /// # Errors
85    ///
86    /// Returns an error if the database insert fails.
87    pub async fn log_compression_failure(
88        &self,
89        conversation_id: ConversationId,
90        compressed_context: &str,
91        failure_reason: &str,
92    ) -> Result<i64, MemoryError> {
93        let ctx = truncate_field(compressed_context);
94        let reason = truncate_field(failure_reason);
95        let id = sqlx::query_scalar(
96            "INSERT INTO compression_failure_pairs \
97             (conversation_id, compressed_context, failure_reason) \
98             VALUES (?, ?, ?) RETURNING id",
99        )
100        .bind(conversation_id.0)
101        .bind(ctx)
102        .bind(reason)
103        .fetch_one(&self.pool)
104        .await?;
105        Ok(id)
106    }
107
108    /// Get unused failure pairs (oldest first), up to `limit`.
109    ///
110    /// # Errors
111    ///
112    /// Returns an error if the database query fails.
113    pub async fn get_unused_failure_pairs(
114        &self,
115        limit: usize,
116    ) -> Result<Vec<CompressionFailurePair>, MemoryError> {
117        let limit = i64::try_from(limit).unwrap_or(i64::MAX);
118        let rows = sqlx::query_as::<_, (i64, i64, String, String, String)>(
119            "SELECT id, conversation_id, compressed_context, failure_reason, created_at \
120             FROM compression_failure_pairs \
121             WHERE used_in_update = 0 \
122             ORDER BY created_at ASC \
123             LIMIT ?",
124        )
125        .bind(limit)
126        .fetch_all(&self.pool)
127        .await?;
128
129        Ok(rows
130            .into_iter()
131            .map(
132                |(id, cid, ctx, reason, created_at)| CompressionFailurePair {
133                    id,
134                    conversation_id: ConversationId(cid),
135                    compressed_context: ctx,
136                    failure_reason: reason,
137                    created_at,
138                },
139            )
140            .collect())
141    }
142
143    /// Mark failure pairs as consumed by the updater.
144    ///
145    /// # Errors
146    ///
147    /// Returns an error if the database update fails.
148    pub async fn mark_failure_pairs_used(&self, ids: &[i64]) -> Result<(), MemoryError> {
149        if ids.is_empty() {
150            return Ok(());
151        }
152        let placeholders: String = ids.iter().map(|_| "?").collect::<Vec<_>>().join(",");
153        let query = format!(
154            "UPDATE compression_failure_pairs SET used_in_update = 1 WHERE id IN ({placeholders})"
155        );
156        let mut q = sqlx::query(&query);
157        for id in ids {
158            q = q.bind(id);
159        }
160        q.execute(&self.pool).await?;
161        Ok(())
162    }
163
164    /// Count unused failure pairs.
165    ///
166    /// # Errors
167    ///
168    /// Returns an error if the database query fails.
169    pub async fn count_unused_failure_pairs(&self) -> Result<i64, MemoryError> {
170        let count = sqlx::query_scalar(
171            "SELECT COUNT(*) FROM compression_failure_pairs WHERE used_in_update = 0",
172        )
173        .fetch_one(&self.pool)
174        .await?;
175        Ok(count)
176    }
177
178    /// Delete old used failure pairs, keeping the most recent `keep_recent` unused pairs.
179    ///
180    /// Removes all rows where `used_in_update = 1`. Unused rows are managed by the
181    /// `max_stored_pairs` enforcement below: if there are more than `keep_recent` unused pairs,
182    /// the oldest excess rows are deleted.
183    ///
184    /// # Errors
185    ///
186    /// Returns an error if the database query fails.
187    pub async fn cleanup_old_failure_pairs(&self, keep_recent: usize) -> Result<(), MemoryError> {
188        // Delete all used pairs (they've already been processed).
189        sqlx::query("DELETE FROM compression_failure_pairs WHERE used_in_update = 1")
190            .execute(&self.pool)
191            .await?;
192
193        // Keep only the most recent `keep_recent` unused pairs.
194        let keep = i64::try_from(keep_recent).unwrap_or(i64::MAX);
195        sqlx::query(
196            "DELETE FROM compression_failure_pairs \
197             WHERE used_in_update = 0 \
198             AND id NOT IN ( \
199                 SELECT id FROM compression_failure_pairs \
200                 WHERE used_in_update = 0 \
201                 ORDER BY created_at DESC \
202                 LIMIT ? \
203             )",
204        )
205        .bind(keep)
206        .execute(&self.pool)
207        .await?;
208
209        Ok(())
210    }
211}
212
#[cfg(test)]
mod tests {
    use super::*;

    // pool_size=1 is required: SQLite :memory: creates an isolated database per
    // connection, so multiple connections would each see an empty schema.
    async fn make_store() -> SqliteStore {
        SqliteStore::with_pool_size(":memory:", 1)
            .await
            .expect("in-memory SqliteStore")
    }

    #[tokio::test]
    async fn load_guidelines_returns_defaults_when_empty() {
        let store = make_store().await;
        let (version, text) = store.load_compression_guidelines().await.unwrap();
        assert_eq!(version, 0);
        assert!(text.is_empty());
    }

    #[tokio::test]
    async fn save_and_load_guidelines() {
        let store = make_store().await;
        let v1 = store
            .save_compression_guidelines("always preserve file paths", 4)
            .await
            .unwrap();
        assert_eq!(v1, 1);
        let v2 = store
            .save_compression_guidelines("always preserve file paths\nalways preserve errors", 8)
            .await
            .unwrap();
        assert_eq!(v2, 2);
        // Loading should return the latest version.
        let (v, text) = store.load_compression_guidelines().await.unwrap();
        assert_eq!(v, 2);
        assert!(text.contains("errors"));
    }

    #[tokio::test]
    async fn log_and_count_failure_pairs() {
        let store = make_store().await;
        let cid = ConversationId(store.create_conversation().await.unwrap().0);
        store
            .log_compression_failure(cid, "compressed ctx", "i don't recall that")
            .await
            .unwrap();
        let count = store.count_unused_failure_pairs().await.unwrap();
        assert_eq!(count, 1);
    }

    #[tokio::test]
    async fn get_unused_pairs_sorted_oldest_first() {
        let store = make_store().await;
        let cid = ConversationId(store.create_conversation().await.unwrap().0);
        store
            .log_compression_failure(cid, "ctx A", "reason A")
            .await
            .unwrap();
        store
            .log_compression_failure(cid, "ctx B", "reason B")
            .await
            .unwrap();
        let pairs = store.get_unused_failure_pairs(10).await.unwrap();
        assert_eq!(pairs.len(), 2);
        assert_eq!(pairs[0].compressed_context, "ctx A");
    }

    #[tokio::test]
    async fn mark_pairs_used_reduces_count() {
        let store = make_store().await;
        let cid = ConversationId(store.create_conversation().await.unwrap().0);
        let id = store
            .log_compression_failure(cid, "ctx", "reason")
            .await
            .unwrap();
        store.mark_failure_pairs_used(&[id]).await.unwrap();
        let count = store.count_unused_failure_pairs().await.unwrap();
        assert_eq!(count, 0);
    }

    #[tokio::test]
    async fn cleanup_deletes_used_and_trims_unused() {
        let store = make_store().await;
        let cid = ConversationId(store.create_conversation().await.unwrap().0);
        // Add 3 pairs and mark 1 used.
        let id1 = store
            .log_compression_failure(cid, "ctx1", "r1")
            .await
            .unwrap();
        store
            .log_compression_failure(cid, "ctx2", "r2")
            .await
            .unwrap();
        store
            .log_compression_failure(cid, "ctx3", "r3")
            .await
            .unwrap();
        store.mark_failure_pairs_used(&[id1]).await.unwrap();
        // Cleanup: keep at most 1 unused.
        store.cleanup_old_failure_pairs(1).await.unwrap();
        let count = store.count_unused_failure_pairs().await.unwrap();
        assert_eq!(count, 1, "only 1 unused pair should remain");
    }

    // Pure synchronous test: `truncate_field` does no I/O, so there is no
    // reason to spin up a tokio runtime for it (was `#[tokio::test] async`).
    #[test]
    fn truncate_field_respects_char_boundary() {
        let s = "а".repeat(5000); // Cyrillic 'а', 2 bytes each
        let truncated = truncate_field(&s);
        assert!(truncated.len() <= MAX_FIELD_CHARS);
        assert!(s.is_char_boundary(truncated.len()));
    }

    #[tokio::test]
    async fn unique_constraint_prevents_duplicate_version() {
        let store = make_store().await;
        // Insert version 1 via the public API.
        store.save_compression_guidelines("first", 1).await.unwrap();
        // store.pool() access is intentional: we need direct pool access to bypass
        // the public API and test the UNIQUE constraint at the SQL level.
        let result = sqlx::query(
            "INSERT INTO compression_guidelines (version, guidelines, token_count) VALUES (1, 'dup', 0)",
        )
        .execute(store.pool())
        .await;
        assert!(
            result.is_err(),
            "duplicate version insert should violate UNIQUE constraint"
        );
    }

    /// Concurrent saves must produce strictly unique versions with no collisions.
    ///
    /// Uses a file-backed database because SQLite `:memory:` creates an isolated
    /// database per connection — a multi-connection pool over `:memory:` would give
    /// each writer its own empty schema and cannot test shared-state atomicity.
    #[tokio::test]
    async fn concurrent_saves_produce_unique_versions() {
        use std::collections::HashSet;
        use std::sync::Arc;

        let dir = tempfile::tempdir().expect("tempdir");
        let db_path = dir.path().join("test.db");
        let store = Arc::new(
            SqliteStore::with_pool_size(db_path.to_str().expect("utf8 path"), 4)
                .await
                .expect("file-backed SqliteStore"),
        );

        let tasks: Vec<_> = (0..8_i64)
            .map(|i| {
                let s = Arc::clone(&store);
                tokio::spawn(async move {
                    s.save_compression_guidelines(&format!("guideline {i}"), i)
                        .await
                        .expect("concurrent save must succeed")
                })
            })
            .collect();

        let mut versions = HashSet::new();
        for task in tasks {
            let v = task.await.expect("task must not panic");
            assert!(versions.insert(v), "version {v} appeared more than once");
        }
        assert_eq!(
            versions.len(),
            8,
            "all 8 saves must produce distinct versions"
        );
    }
}