1use std::borrow::Cow;
7use std::sync::LazyLock;
8
9use regex::Regex;
10
11use crate::error::MemoryError;
12use crate::sqlite::SqliteStore;
13use crate::types::ConversationId;
14
/// Matches well-known secret/token prefixes (OpenAI `sk-`, AWS `AKIA`,
/// GitHub `ghp_`/`gho_`, Slack `xoxb-`/`xoxp-`, Google `AIza`/`ya29.`,
/// GitLab `glpat-`, Hugging Face `hf_`, npm `npm_`, Docker `dckr_pat_`,
/// PEM `-----BEGIN`) and consumes the rest of the token up to whitespace,
/// quotes, or common delimiter characters.
static SECRET_RE: LazyLock<Regex> = LazyLock::new(|| {
    Regex::new(
        r#"(?:sk-|sk_live_|sk_test_|AKIA|ghp_|gho_|-----BEGIN|xoxb-|xoxp-|AIza|ya29\.|glpat-|hf_|npm_|dckr_pat_)[^\s"'`,;\{\}\[\]]*"#,
    )
    .expect("secret regex")
});
21
/// Matches absolute filesystem paths under common user/system roots
/// (`/home/`, `/Users/`, `/root/`, `/tmp/`, `/var/`) up to whitespace,
/// quotes, or common delimiter characters.
static PATH_RE: LazyLock<Regex> = LazyLock::new(|| {
    Regex::new(r#"(?:/home/|/Users/|/root/|/tmp/|/var/)[^\s"'`,;\{\}\[\]]*"#).expect("path regex")
});
25
/// Matches `Authorization: Bearer <token>` headers case-insensitively.
/// Capture group 1 keeps the header prefix so only the token is replaced.
static BEARER_RE: LazyLock<Regex> =
    LazyLock::new(|| Regex::new(r"(?i)(Authorization:\s*Bearer\s+)\S+").expect("bearer regex"));
29
/// Matches three-part base64url JWTs (`header.payload.signature`). The
/// signature part may be empty (`*` quantifier) so unsigned alg=none
/// tokens are caught too; a two-part `eyJ...` string does not match.
static JWT_RE: LazyLock<Regex> = LazyLock::new(|| {
    Regex::new(r"eyJ[A-Za-z0-9_-]+\.[A-Za-z0-9_-]+\.[A-Za-z0-9_-]*").expect("jwt regex")
});
35
36fn redact_sensitive(text: &str) -> Cow<'_, str> {
40 let s0: Cow<'_, str> = SECRET_RE.replace_all(text, "[REDACTED]");
44 let s1: Cow<'_, str> = match PATH_RE.replace_all(s0.as_ref(), "[PATH]") {
45 Cow::Borrowed(_) => s0,
46 Cow::Owned(o) => Cow::Owned(o),
47 };
48 let s2: Cow<'_, str> = match BEARER_RE.replace_all(s1.as_ref(), "${1}[REDACTED]") {
50 Cow::Borrowed(_) => s1,
51 Cow::Owned(o) => Cow::Owned(o),
52 };
53 match JWT_RE.replace_all(s2.as_ref(), "[REDACTED_JWT]") {
54 Cow::Borrowed(_) => s2,
55 Cow::Owned(o) => Cow::Owned(o),
56 }
57}
58
/// One logged compression failure: the compressed context plus the reason
/// compression was judged to have lost information.
#[derive(Debug, Clone)]
pub struct CompressionFailurePair {
    /// Row id in `compression_failure_pairs`.
    pub id: i64,
    /// Conversation the failure occurred in.
    pub conversation_id: ConversationId,
    /// The compressed context, as stored (redacted and truncated on insert).
    pub compressed_context: String,
    /// Why the compression failed, as stored (redacted and truncated on insert).
    pub failure_reason: String,
    /// Row creation timestamp string as produced by SQLite.
    pub created_at: String,
}
69
/// Maximum stored size per logged field. NOTE: despite the name, this is a
/// limit in *bytes* (`is_char_boundary`/slicing below operate on byte
/// indices), not in characters.
const MAX_FIELD_CHARS: usize = 4096;

/// Truncates `s` to at most `MAX_FIELD_CHARS` bytes, backing up to the
/// nearest `char` boundary so the result is always valid UTF-8.
fn truncate_field(s: &str) -> &str {
    // Common case: already fits. The early return also avoids walking an
    // index down from MAX_FIELD_CHARS to s.len() for every short string,
    // which the boundary loop would otherwise do.
    if s.len() <= MAX_FIELD_CHARS {
        return s;
    }
    // s.len() > MAX_FIELD_CHARS, so the index is in range. Back up until it
    // lands on a char boundary: at most 3 steps (UTF-8 chars are <= 4
    // bytes), and index 0 is always a boundary, so this terminates.
    let mut idx = MAX_FIELD_CHARS;
    while !s.is_char_boundary(idx) {
        idx -= 1;
    }
    &s[..idx]
}
80
81impl SqliteStore {
82 pub async fn load_compression_guidelines(
93 &self,
94 conversation_id: Option<ConversationId>,
95 ) -> Result<(i64, String), MemoryError> {
96 let row = sqlx::query_as::<_, (i64, String)>(
97 "SELECT version, guidelines FROM compression_guidelines \
103 WHERE conversation_id = ? OR conversation_id IS NULL \
104 ORDER BY CASE WHEN conversation_id IS NOT NULL THEN 0 ELSE 1 END, \
105 version DESC \
106 LIMIT 1",
107 )
108 .bind(conversation_id.map(|c| c.0))
109 .fetch_optional(&self.pool)
110 .await?;
111
112 Ok(row.unwrap_or((0, String::new())))
113 }
114
115 pub async fn load_compression_guidelines_meta(
126 &self,
127 conversation_id: Option<ConversationId>,
128 ) -> Result<(i64, String), MemoryError> {
129 let row = sqlx::query_as::<_, (i64, String)>(
130 "SELECT version, created_at FROM compression_guidelines \
131 WHERE conversation_id = ? OR conversation_id IS NULL \
132 ORDER BY CASE WHEN conversation_id IS NOT NULL THEN 0 ELSE 1 END, \
133 version DESC \
134 LIMIT 1",
135 )
136 .bind(conversation_id.map(|c| c.0))
137 .fetch_optional(&self.pool)
138 .await?;
139
140 Ok(row.unwrap_or((0, String::new())))
141 }
142
143 pub async fn save_compression_guidelines(
160 &self,
161 guidelines: &str,
162 token_count: i64,
163 conversation_id: Option<ConversationId>,
164 ) -> Result<i64, MemoryError> {
165 let new_version: i64 = sqlx::query_scalar(
169 "INSERT INTO compression_guidelines (version, guidelines, token_count, conversation_id) \
170 SELECT COALESCE(MAX(version), 0) + 1, ?, ?, ? \
171 FROM compression_guidelines \
172 RETURNING version",
173 )
174 .bind(guidelines)
175 .bind(token_count)
176 .bind(conversation_id.map(|c| c.0))
177 .fetch_one(&self.pool)
178 .await?;
179 Ok(new_version)
180 }
181
182 pub async fn log_compression_failure(
191 &self,
192 conversation_id: ConversationId,
193 compressed_context: &str,
194 failure_reason: &str,
195 ) -> Result<i64, MemoryError> {
196 let ctx = redact_sensitive(compressed_context);
197 let ctx = truncate_field(&ctx);
198 let reason = redact_sensitive(failure_reason);
199 let reason = truncate_field(&reason);
200 let id = sqlx::query_scalar(
201 "INSERT INTO compression_failure_pairs \
202 (conversation_id, compressed_context, failure_reason) \
203 VALUES (?, ?, ?) RETURNING id",
204 )
205 .bind(conversation_id.0)
206 .bind(ctx)
207 .bind(reason)
208 .fetch_one(&self.pool)
209 .await?;
210 Ok(id)
211 }
212
213 pub async fn get_unused_failure_pairs(
219 &self,
220 limit: usize,
221 ) -> Result<Vec<CompressionFailurePair>, MemoryError> {
222 let limit = i64::try_from(limit).unwrap_or(i64::MAX);
223 let rows = sqlx::query_as::<_, (i64, i64, String, String, String)>(
224 "SELECT id, conversation_id, compressed_context, failure_reason, created_at \
225 FROM compression_failure_pairs \
226 WHERE used_in_update = 0 \
227 ORDER BY created_at ASC \
228 LIMIT ?",
229 )
230 .bind(limit)
231 .fetch_all(&self.pool)
232 .await?;
233
234 Ok(rows
235 .into_iter()
236 .map(
237 |(id, cid, ctx, reason, created_at)| CompressionFailurePair {
238 id,
239 conversation_id: ConversationId(cid),
240 compressed_context: ctx,
241 failure_reason: reason,
242 created_at,
243 },
244 )
245 .collect())
246 }
247
248 pub async fn mark_failure_pairs_used(&self, ids: &[i64]) -> Result<(), MemoryError> {
254 if ids.is_empty() {
255 return Ok(());
256 }
257 let placeholders: String = ids.iter().map(|_| "?").collect::<Vec<_>>().join(",");
258 let query = format!(
259 "UPDATE compression_failure_pairs SET used_in_update = 1 WHERE id IN ({placeholders})"
260 );
261 let mut q = sqlx::query(&query);
262 for id in ids {
263 q = q.bind(id);
264 }
265 q.execute(&self.pool).await?;
266 Ok(())
267 }
268
269 pub async fn count_unused_failure_pairs(&self) -> Result<i64, MemoryError> {
275 let count = sqlx::query_scalar(
276 "SELECT COUNT(*) FROM compression_failure_pairs WHERE used_in_update = 0",
277 )
278 .fetch_one(&self.pool)
279 .await?;
280 Ok(count)
281 }
282
283 pub async fn cleanup_old_failure_pairs(&self, keep_recent: usize) -> Result<(), MemoryError> {
293 sqlx::query("DELETE FROM compression_failure_pairs WHERE used_in_update = 1")
295 .execute(&self.pool)
296 .await?;
297
298 let keep = i64::try_from(keep_recent).unwrap_or(i64::MAX);
300 sqlx::query(
301 "DELETE FROM compression_failure_pairs \
302 WHERE used_in_update = 0 \
303 AND id NOT IN ( \
304 SELECT id FROM compression_failure_pairs \
305 WHERE used_in_update = 0 \
306 ORDER BY created_at DESC \
307 LIMIT ? \
308 )",
309 )
310 .bind(keep)
311 .execute(&self.pool)
312 .await?;
313
314 Ok(())
315 }
316}
317
#[cfg(test)]
mod tests {
    use super::*;

    // Fresh in-memory store; pool size 1 keeps all queries on one connection
    // so the :memory: database is shared across the test's statements.
    async fn make_store() -> SqliteStore {
        SqliteStore::with_pool_size(":memory:", 1)
            .await
            .expect("in-memory SqliteStore")
    }

    #[tokio::test]
    async fn load_guidelines_meta_returns_defaults_when_empty() {
        let store = make_store().await;
        let (version, created_at) = store.load_compression_guidelines_meta(None).await.unwrap();
        assert_eq!(version, 0);
        assert!(created_at.is_empty());
    }

    #[tokio::test]
    async fn load_guidelines_meta_returns_version_and_created_at() {
        let store = make_store().await;
        store
            .save_compression_guidelines("keep file paths", 4, None)
            .await
            .unwrap();
        let (version, created_at) = store.load_compression_guidelines_meta(None).await.unwrap();
        assert_eq!(version, 1);
        assert!(!created_at.is_empty(), "created_at should be populated");
    }

    #[tokio::test]
    async fn load_guidelines_returns_defaults_when_empty() {
        let store = make_store().await;
        let (version, text) = store.load_compression_guidelines(None).await.unwrap();
        assert_eq!(version, 0);
        assert!(text.is_empty());
    }

    #[tokio::test]
    async fn save_and_load_guidelines() {
        let store = make_store().await;
        let v1 = store
            .save_compression_guidelines("always preserve file paths", 4, None)
            .await
            .unwrap();
        assert_eq!(v1, 1);
        let v2 = store
            .save_compression_guidelines(
                "always preserve file paths\nalways preserve errors",
                8,
                None,
            )
            .await
            .unwrap();
        assert_eq!(v2, 2);
        let (v, text) = store.load_compression_guidelines(None).await.unwrap();
        assert_eq!(v, 2);
        assert!(text.contains("errors"));
    }

    #[tokio::test]
    async fn load_guidelines_prefers_conversation_specific() {
        let store = make_store().await;
        let cid = ConversationId(store.create_conversation().await.unwrap().0);
        store
            .save_compression_guidelines("global rule", 2, None)
            .await
            .unwrap();
        store
            .save_compression_guidelines("conversation rule", 2, Some(cid))
            .await
            .unwrap();
        let (_, text) = store.load_compression_guidelines(Some(cid)).await.unwrap();
        assert_eq!(text, "conversation rule");
    }

    #[tokio::test]
    async fn load_guidelines_falls_back_to_global() {
        let store = make_store().await;
        let cid = ConversationId(store.create_conversation().await.unwrap().0);
        store
            .save_compression_guidelines("global rule", 2, None)
            .await
            .unwrap();
        let (_, text) = store.load_compression_guidelines(Some(cid)).await.unwrap();
        assert_eq!(text, "global rule");
    }

    #[tokio::test]
    async fn load_guidelines_none_returns_global_only() {
        let store = make_store().await;
        let cid = ConversationId(store.create_conversation().await.unwrap().0);
        store
            .save_compression_guidelines("conversation rule", 2, Some(cid))
            .await
            .unwrap();
        // None scope must not see conversation-scoped rows.
        let (version, text) = store.load_compression_guidelines(None).await.unwrap();
        assert_eq!(version, 0);
        assert!(text.is_empty());
    }

    #[tokio::test]
    async fn load_guidelines_scope_isolation() {
        let store = make_store().await;
        let cid_a = ConversationId(store.create_conversation().await.unwrap().0);
        let cid_b = ConversationId(store.create_conversation().await.unwrap().0);

        store
            .save_compression_guidelines("Use bullet points", 1, None)
            .await
            .unwrap();
        store
            .save_compression_guidelines("Be concise", 2, Some(cid_a))
            .await
            .unwrap();

        let (_, text_b) = store
            .load_compression_guidelines(Some(cid_b))
            .await
            .unwrap();
        assert_eq!(
            text_b, "Use bullet points",
            "conversation B must see global guideline"
        );

        let (_, text_a) = store
            .load_compression_guidelines(Some(cid_a))
            .await
            .unwrap();
        assert_eq!(
            text_a, "Be concise",
            "conversation A must prefer its own guideline over global"
        );

        let (_, text_global) = store.load_compression_guidelines(None).await.unwrap();
        assert_eq!(
            text_global, "Use bullet points",
            "None scope must see only the global guideline"
        );
    }

    #[tokio::test]
    async fn save_with_nonexistent_conversation_id_fails() {
        let store = make_store().await;
        let nonexistent = ConversationId(99999);
        let result = store
            .save_compression_guidelines("rule", 1, Some(nonexistent))
            .await;
        assert!(
            result.is_err(),
            "FK violation expected for nonexistent conversation_id"
        );
    }

    #[tokio::test]
    async fn cascade_delete_removes_conversation_guidelines() {
        let store = make_store().await;
        let cid = ConversationId(store.create_conversation().await.unwrap().0);
        store
            .save_compression_guidelines("rule", 1, Some(cid))
            .await
            .unwrap();
        sqlx::query("DELETE FROM conversations WHERE id = ?")
            .bind(cid.0)
            .execute(store.pool())
            .await
            .unwrap();
        let (version, text) = store.load_compression_guidelines(Some(cid)).await.unwrap();
        assert_eq!(version, 0);
        assert!(text.is_empty());
    }

    #[tokio::test]
    async fn log_and_count_failure_pairs() {
        let store = make_store().await;
        let cid = ConversationId(store.create_conversation().await.unwrap().0);
        store
            .log_compression_failure(cid, "compressed ctx", "i don't recall that")
            .await
            .unwrap();
        let count = store.count_unused_failure_pairs().await.unwrap();
        assert_eq!(count, 1);
    }

    #[tokio::test]
    async fn get_unused_pairs_sorted_oldest_first() {
        let store = make_store().await;
        let cid = ConversationId(store.create_conversation().await.unwrap().0);
        store
            .log_compression_failure(cid, "ctx A", "reason A")
            .await
            .unwrap();
        store
            .log_compression_failure(cid, "ctx B", "reason B")
            .await
            .unwrap();
        let pairs = store.get_unused_failure_pairs(10).await.unwrap();
        assert_eq!(pairs.len(), 2);
        assert_eq!(pairs[0].compressed_context, "ctx A");
    }

    #[tokio::test]
    async fn mark_pairs_used_reduces_count() {
        let store = make_store().await;
        let cid = ConversationId(store.create_conversation().await.unwrap().0);
        let id = store
            .log_compression_failure(cid, "ctx", "reason")
            .await
            .unwrap();
        store.mark_failure_pairs_used(&[id]).await.unwrap();
        let count = store.count_unused_failure_pairs().await.unwrap();
        assert_eq!(count, 0);
    }

    #[tokio::test]
    async fn cleanup_deletes_used_and_trims_unused() {
        let store = make_store().await;
        let cid = ConversationId(store.create_conversation().await.unwrap().0);
        let id1 = store
            .log_compression_failure(cid, "ctx1", "r1")
            .await
            .unwrap();
        store
            .log_compression_failure(cid, "ctx2", "r2")
            .await
            .unwrap();
        store
            .log_compression_failure(cid, "ctx3", "r3")
            .await
            .unwrap();
        store.mark_failure_pairs_used(&[id1]).await.unwrap();
        store.cleanup_old_failure_pairs(1).await.unwrap();
        let count = store.count_unused_failure_pairs().await.unwrap();
        assert_eq!(count, 1, "only 1 unused pair should remain");
    }

    #[test]
    fn redact_sensitive_api_key_is_redacted() {
        let result = redact_sensitive("token sk-abc123def456 used for auth");
        assert!(result.contains("[REDACTED]"), "API key must be redacted");
        assert!(
            !result.contains("sk-abc123"),
            "original key must not appear"
        );
    }

    #[test]
    fn redact_sensitive_plain_text_borrows() {
        let text = "safe text, no secrets here";
        let result = redact_sensitive(text);
        assert!(
            matches!(result, Cow::Borrowed(_)),
            "plain text must return Cow::Borrowed (zero-alloc)"
        );
    }

    #[test]
    fn redact_sensitive_filesystem_path_is_redacted() {
        let result = redact_sensitive("config loaded from /Users/dev/project/config.toml");
        assert!(
            result.contains("[PATH]"),
            "filesystem path must be redacted"
        );
        assert!(
            !result.contains("/Users/dev/"),
            "original path must not appear"
        );
    }

    #[test]
    fn redact_sensitive_combined_secret_and_path() {
        let result = redact_sensitive("key sk-abc at /home/user/file");
        assert!(result.contains("[REDACTED]"), "secret must be redacted");
        assert!(result.contains("[PATH]"), "path must be redacted");
    }

    #[tokio::test]
    async fn log_compression_failure_redacts_secrets() {
        let store = make_store().await;
        let cid = ConversationId(store.create_conversation().await.unwrap().0);
        store
            .log_compression_failure(cid, "token sk-abc123def456 used for auth", "context lost")
            .await
            .unwrap();
        let pairs = store.get_unused_failure_pairs(10).await.unwrap();
        assert_eq!(pairs.len(), 1);
        assert!(
            pairs[0].compressed_context.contains("[REDACTED]"),
            "stored context must have redacted secret"
        );
        assert!(
            !pairs[0].compressed_context.contains("sk-abc123"),
            "stored context must not contain raw secret"
        );
    }

    #[tokio::test]
    async fn log_compression_failure_redacts_paths() {
        let store = make_store().await;
        let cid = ConversationId(store.create_conversation().await.unwrap().0);
        store
            .log_compression_failure(cid, "/Users/dev/project/config.toml was loaded", "lost")
            .await
            .unwrap();
        let pairs = store.get_unused_failure_pairs(10).await.unwrap();
        assert!(
            pairs[0].compressed_context.contains("[PATH]"),
            "stored context must have redacted path"
        );
        assert!(
            !pairs[0].compressed_context.contains("/Users/dev/"),
            "stored context must not contain raw path"
        );
    }

    #[tokio::test]
    async fn log_compression_failure_reason_also_redacted() {
        let store = make_store().await;
        let cid = ConversationId(store.create_conversation().await.unwrap().0);
        store
            .log_compression_failure(cid, "some context", "secret ghp_abc123xyz was leaked")
            .await
            .unwrap();
        let pairs = store.get_unused_failure_pairs(10).await.unwrap();
        assert!(
            pairs[0].failure_reason.contains("[REDACTED]"),
            "failure_reason must also be redacted"
        );
        assert!(
            !pairs[0].failure_reason.contains("ghp_abc123xyz"),
            "raw secret must not appear in failure_reason"
        );
    }

    // Plain #[test]: this exercises a pure function and awaits nothing, so
    // spinning up a tokio runtime is unnecessary.
    #[test]
    fn truncate_field_respects_char_boundary() {
        // Cyrillic "а" is 2 bytes in UTF-8, so a naive byte slice could
        // split a character in half.
        let s = "а".repeat(5000);
        let truncated = truncate_field(&s);
        assert!(truncated.len() <= MAX_FIELD_CHARS);
        assert!(s.is_char_boundary(truncated.len()));
    }

    #[tokio::test]
    async fn unique_constraint_prevents_duplicate_version() {
        let store = make_store().await;
        store
            .save_compression_guidelines("first", 1, None)
            .await
            .unwrap();
        let result = sqlx::query(
            "INSERT INTO compression_guidelines (version, guidelines, token_count) VALUES (1, 'dup', 0)",
        )
        .execute(store.pool())
        .await;
        assert!(
            result.is_err(),
            "duplicate version insert should violate UNIQUE constraint"
        );
    }

    #[test]
    fn redact_sensitive_bearer_token_is_redacted() {
        let result =
            redact_sensitive("Authorization: Bearer eyJhbGciOiJSUzI1NiJ9.payload.signature");
        assert!(
            result.contains("[REDACTED]"),
            "Bearer token must be redacted: {result}"
        );
        assert!(
            !result.contains("eyJhbGciOiJSUzI1NiJ9"),
            "raw JWT header must not appear: {result}"
        );
        assert!(
            result.contains("Authorization:"),
            "header name must be preserved: {result}"
        );
    }

    #[test]
    fn redact_sensitive_bearer_token_case_insensitive() {
        let result =
            redact_sensitive("authorization: bearer eyJhbGciOiJSUzI1NiJ9.payload.signature");
        assert!(
            result.contains("[REDACTED]"),
            "Bearer header match must be case-insensitive: {result}"
        );
    }

    #[test]
    fn redact_sensitive_standalone_jwt_is_redacted() {
        let jwt = "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiJ1c2VyMTIzIn0.SflKxwRJSMeKKF2";
        let input = format!("token value: {jwt} was found in logs");
        let result = redact_sensitive(&input);
        assert!(
            result.contains("[REDACTED_JWT]"),
            "standalone JWT must be replaced with [REDACTED_JWT]: {result}"
        );
        assert!(
            !result.contains("eyJhbGci"),
            "raw JWT must not appear: {result}"
        );
    }

    #[test]
    fn redact_sensitive_mixed_content_all_redacted() {
        let input =
            "key sk-abc123 at /home/user/f with Authorization: Bearer eyJhbG.pay.sig and eyJx.b.c";
        let result = redact_sensitive(input);
        assert!(result.contains("[REDACTED]"), "API key must be redacted");
        assert!(result.contains("[PATH]"), "path must be redacted");
        assert!(!result.contains("sk-abc123"), "raw API key must not appear");
        assert!(!result.contains("eyJhbG"), "raw JWT must not appear");
    }

    #[test]
    fn redact_sensitive_partial_jwt_not_redacted() {
        let input = "eyJhbGciOiJSUzI1NiJ9.onlytwoparts";
        let result = redact_sensitive(input);
        assert!(
            !result.contains("[REDACTED_JWT]"),
            "two-part eyJ string must not be treated as JWT: {result}"
        );
        assert!(
            matches!(result, Cow::Borrowed(_)),
            "no-match input must return Cow::Borrowed: {result}"
        );
    }

    #[test]
    fn redact_sensitive_alg_none_jwt_empty_signature_redacted() {
        let input =
            "token: eyJhbGciOiJub25lIn0.eyJzdWIiOiJ1c2VyIn0. was submitted without signature";
        let result = redact_sensitive(input);
        assert!(
            result.contains("[REDACTED_JWT]"),
            "alg=none JWT with empty signature must be redacted: {result}"
        );
        assert!(
            !result.contains("eyJhbGciOiJub25lIn0"),
            "raw alg=none JWT header must not appear: {result}"
        );
    }

    #[tokio::test]
    async fn concurrent_saves_produce_unique_versions() {
        use std::collections::HashSet;
        use std::sync::Arc;

        // File-backed store: concurrent tasks need a real shared database,
        // not per-connection :memory: instances.
        let dir = tempfile::tempdir().expect("tempdir");
        let db_path = dir.path().join("test.db");
        let store = Arc::new(
            SqliteStore::with_pool_size(db_path.to_str().expect("utf8 path"), 4)
                .await
                .expect("file-backed SqliteStore"),
        );

        let tasks: Vec<_> = (0..8_i64)
            .map(|i| {
                let s = Arc::clone(&store);
                tokio::spawn(async move {
                    s.save_compression_guidelines(&format!("guideline {i}"), i, None)
                        .await
                        .expect("concurrent save must succeed")
                })
            })
            .collect();

        let mut versions = HashSet::new();
        for task in tasks {
            let v = task.await.expect("task must not panic");
            assert!(versions.insert(v), "version {v} appeared more than once");
        }
        assert_eq!(
            versions.len(),
            8,
            "all 8 saves must produce distinct versions"
        );
    }
}