1use crate::error::MemoryError;
7use crate::sqlite::SqliteStore;
8use crate::types::ConversationId;
9
/// One recorded compression failure: the compressed context that was in
/// effect paired with the reason the compression was judged a failure.
///
/// Rows are written by `SqliteStore::log_compression_failure` and read back
/// (oldest first) by `SqliteStore::get_unused_failure_pairs` to drive
/// guideline updates.
#[derive(Debug, Clone)]
pub struct CompressionFailurePair {
    /// Row id in `compression_failure_pairs`.
    pub id: i64,
    /// Conversation in which the failure occurred.
    pub conversation_id: ConversationId,
    /// Compressed context at failure time; truncated to `MAX_FIELD_CHARS`
    /// bytes on insert.
    pub compressed_context: String,
    /// Description of the failure; truncated to `MAX_FIELD_CHARS` bytes on
    /// insert.
    pub failure_reason: String,
    /// Creation timestamp as stored by SQLite, in its string form.
    pub created_at: String,
}
20
/// Maximum number of bytes persisted for free-text failure-pair fields.
const MAX_FIELD_CHARS: usize = 4096;

/// Truncates `s` to at most `MAX_FIELD_CHARS` bytes without splitting a
/// UTF-8 code point, returning a borrowed prefix.
///
/// Strings already within the limit are returned unchanged. (The previous
/// version started the boundary back-off at `MAX_FIELD_CHARS` even for
/// short inputs, walking up to ~4096 `is_char_boundary` checks per call;
/// the early return makes the common case O(1).)
fn truncate_field(s: &str) -> &str {
    if s.len() <= MAX_FIELD_CHARS {
        return s;
    }
    // Back off from the byte limit until we land on a char boundary. A
    // UTF-8 code point is at most 4 bytes wide, so this loops at most 3
    // times; idx can never reach 0 because 0 is always a boundary.
    let mut idx = MAX_FIELD_CHARS;
    while !s.is_char_boundary(idx) {
        idx -= 1;
    }
    &s[..idx]
}
31
32impl SqliteStore {
33 pub async fn load_compression_guidelines(&self) -> Result<(i64, String), MemoryError> {
41 let row = sqlx::query_as::<_, (i64, String)>(
42 "SELECT version, guidelines FROM compression_guidelines ORDER BY version DESC LIMIT 1",
43 )
44 .fetch_optional(&self.pool)
45 .await?;
46
47 Ok(row.unwrap_or((0, String::new())))
48 }
49
50 pub async fn save_compression_guidelines(
59 &self,
60 guidelines: &str,
61 token_count: i64,
62 ) -> Result<i64, MemoryError> {
63 let new_version: i64 = sqlx::query_scalar(
67 "INSERT INTO compression_guidelines (version, guidelines, token_count) \
68 SELECT COALESCE(MAX(version), 0) + 1, ?, ? \
69 FROM compression_guidelines \
70 RETURNING version",
71 )
72 .bind(guidelines)
73 .bind(token_count)
74 .fetch_one(&self.pool)
75 .await?;
76 Ok(new_version)
77 }
78
79 pub async fn log_compression_failure(
88 &self,
89 conversation_id: ConversationId,
90 compressed_context: &str,
91 failure_reason: &str,
92 ) -> Result<i64, MemoryError> {
93 let ctx = truncate_field(compressed_context);
94 let reason = truncate_field(failure_reason);
95 let id = sqlx::query_scalar(
96 "INSERT INTO compression_failure_pairs \
97 (conversation_id, compressed_context, failure_reason) \
98 VALUES (?, ?, ?) RETURNING id",
99 )
100 .bind(conversation_id.0)
101 .bind(ctx)
102 .bind(reason)
103 .fetch_one(&self.pool)
104 .await?;
105 Ok(id)
106 }
107
108 pub async fn get_unused_failure_pairs(
114 &self,
115 limit: usize,
116 ) -> Result<Vec<CompressionFailurePair>, MemoryError> {
117 let limit = i64::try_from(limit).unwrap_or(i64::MAX);
118 let rows = sqlx::query_as::<_, (i64, i64, String, String, String)>(
119 "SELECT id, conversation_id, compressed_context, failure_reason, created_at \
120 FROM compression_failure_pairs \
121 WHERE used_in_update = 0 \
122 ORDER BY created_at ASC \
123 LIMIT ?",
124 )
125 .bind(limit)
126 .fetch_all(&self.pool)
127 .await?;
128
129 Ok(rows
130 .into_iter()
131 .map(
132 |(id, cid, ctx, reason, created_at)| CompressionFailurePair {
133 id,
134 conversation_id: ConversationId(cid),
135 compressed_context: ctx,
136 failure_reason: reason,
137 created_at,
138 },
139 )
140 .collect())
141 }
142
143 pub async fn mark_failure_pairs_used(&self, ids: &[i64]) -> Result<(), MemoryError> {
149 if ids.is_empty() {
150 return Ok(());
151 }
152 let placeholders: String = ids.iter().map(|_| "?").collect::<Vec<_>>().join(",");
153 let query = format!(
154 "UPDATE compression_failure_pairs SET used_in_update = 1 WHERE id IN ({placeholders})"
155 );
156 let mut q = sqlx::query(&query);
157 for id in ids {
158 q = q.bind(id);
159 }
160 q.execute(&self.pool).await?;
161 Ok(())
162 }
163
164 pub async fn count_unused_failure_pairs(&self) -> Result<i64, MemoryError> {
170 let count = sqlx::query_scalar(
171 "SELECT COUNT(*) FROM compression_failure_pairs WHERE used_in_update = 0",
172 )
173 .fetch_one(&self.pool)
174 .await?;
175 Ok(count)
176 }
177
178 pub async fn cleanup_old_failure_pairs(&self, keep_recent: usize) -> Result<(), MemoryError> {
188 sqlx::query("DELETE FROM compression_failure_pairs WHERE used_in_update = 1")
190 .execute(&self.pool)
191 .await?;
192
193 let keep = i64::try_from(keep_recent).unwrap_or(i64::MAX);
195 sqlx::query(
196 "DELETE FROM compression_failure_pairs \
197 WHERE used_in_update = 0 \
198 AND id NOT IN ( \
199 SELECT id FROM compression_failure_pairs \
200 WHERE used_in_update = 0 \
201 ORDER BY created_at DESC \
202 LIMIT ? \
203 )",
204 )
205 .bind(keep)
206 .execute(&self.pool)
207 .await?;
208
209 Ok(())
210 }
211}
212
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a single-connection in-memory store so each test is isolated.
    async fn make_store() -> SqliteStore {
        SqliteStore::with_pool_size(":memory:", 1)
            .await
            .expect("in-memory SqliteStore")
    }

    #[tokio::test]
    async fn load_guidelines_returns_defaults_when_empty() {
        let store = make_store().await;
        let (version, text) = store.load_compression_guidelines().await.unwrap();
        assert_eq!(version, 0);
        assert!(text.is_empty());
    }

    #[tokio::test]
    async fn save_and_load_guidelines() {
        let store = make_store().await;
        let v1 = store
            .save_compression_guidelines("always preserve file paths", 4)
            .await
            .unwrap();
        assert_eq!(v1, 1);
        let v2 = store
            .save_compression_guidelines("always preserve file paths\nalways preserve errors", 8)
            .await
            .unwrap();
        assert_eq!(v2, 2);
        // Loading must return the latest revision, not the first.
        let (v, text) = store.load_compression_guidelines().await.unwrap();
        assert_eq!(v, 2);
        assert!(text.contains("errors"));
    }

    #[tokio::test]
    async fn log_and_count_failure_pairs() {
        let store = make_store().await;
        let cid = ConversationId(store.create_conversation().await.unwrap().0);
        store
            .log_compression_failure(cid, "compressed ctx", "i don't recall that")
            .await
            .unwrap();
        let count = store.count_unused_failure_pairs().await.unwrap();
        assert_eq!(count, 1);
    }

    #[tokio::test]
    async fn get_unused_pairs_sorted_oldest_first() {
        let store = make_store().await;
        let cid = ConversationId(store.create_conversation().await.unwrap().0);
        store
            .log_compression_failure(cid, "ctx A", "reason A")
            .await
            .unwrap();
        store
            .log_compression_failure(cid, "ctx B", "reason B")
            .await
            .unwrap();
        let pairs = store.get_unused_failure_pairs(10).await.unwrap();
        assert_eq!(pairs.len(), 2);
        assert_eq!(pairs[0].compressed_context, "ctx A");
    }

    #[tokio::test]
    async fn mark_pairs_used_reduces_count() {
        let store = make_store().await;
        let cid = ConversationId(store.create_conversation().await.unwrap().0);
        let id = store
            .log_compression_failure(cid, "ctx", "reason")
            .await
            .unwrap();
        store.mark_failure_pairs_used(&[id]).await.unwrap();
        let count = store.count_unused_failure_pairs().await.unwrap();
        assert_eq!(count, 0);
    }

    #[tokio::test]
    async fn cleanup_deletes_used_and_trims_unused() {
        let store = make_store().await;
        let cid = ConversationId(store.create_conversation().await.unwrap().0);
        let id1 = store
            .log_compression_failure(cid, "ctx1", "r1")
            .await
            .unwrap();
        store
            .log_compression_failure(cid, "ctx2", "r2")
            .await
            .unwrap();
        store
            .log_compression_failure(cid, "ctx3", "r3")
            .await
            .unwrap();
        store.mark_failure_pairs_used(&[id1]).await.unwrap();
        // keep_recent = 1: the used pair is deleted, and of the two unused
        // pairs only the newest survives.
        store.cleanup_old_failure_pairs(1).await.unwrap();
        let count = store.count_unused_failure_pairs().await.unwrap();
        assert_eq!(count, 1, "only 1 unused pair should remain");
    }

    /// `truncate_field` is pure and synchronous, so a plain `#[test]`
    /// suffices — no tokio runtime needed.
    #[test]
    fn truncate_field_respects_char_boundary() {
        // Cyrillic 'а' is 2 bytes in UTF-8, so a naive byte cut could
        // split a code point.
        let s = "а".repeat(5000);
        let truncated = truncate_field(&s);
        assert!(truncated.len() <= MAX_FIELD_CHARS);
        assert!(s.is_char_boundary(truncated.len()));
    }

    #[tokio::test]
    async fn unique_constraint_prevents_duplicate_version() {
        let store = make_store().await;
        store.save_compression_guidelines("first", 1).await.unwrap();
        // Bypass the API to force a second row with the same version.
        let result = sqlx::query(
            "INSERT INTO compression_guidelines (version, guidelines, token_count) VALUES (1, 'dup', 0)",
        )
        .execute(store.pool())
        .await;
        assert!(
            result.is_err(),
            "duplicate version insert should violate UNIQUE constraint"
        );
    }

    #[tokio::test]
    async fn concurrent_saves_produce_unique_versions() {
        use std::collections::HashSet;
        use std::sync::Arc;

        // A file-backed store with a multi-connection pool so saves can
        // genuinely race.
        let dir = tempfile::tempdir().expect("tempdir");
        let db_path = dir.path().join("test.db");
        let store = Arc::new(
            SqliteStore::with_pool_size(db_path.to_str().expect("utf8 path"), 4)
                .await
                .expect("file-backed SqliteStore"),
        );

        let tasks: Vec<_> = (0..8_i64)
            .map(|i| {
                let s = Arc::clone(&store);
                tokio::spawn(async move {
                    s.save_compression_guidelines(&format!("guideline {i}"), i)
                        .await
                        .expect("concurrent save must succeed")
                })
            })
            .collect();

        let mut versions = HashSet::new();
        for task in tasks {
            let v = task.await.expect("task must not panic");
            assert!(versions.insert(v), "version {v} appeared more than once");
        }
        assert_eq!(
            versions.len(),
            8,
            "all 8 saves must produce distinct versions"
        );
    }
}