1use std::time::Instant;
12
13use schemars::JsonSchema;
14use serde::{Deserialize, Serialize};
15use zeph_llm::any::AnyProvider;
16use zeph_llm::provider::{LlmProvider as _, Message, MessageMetadata, MessagePart, Role};
17
18use crate::error::MemoryError;
19
/// A factual question generated from the pre-compaction conversation,
/// paired with the answer a faithful summary should still contain.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct ProbeQuestion {
    /// The question posed against the compacted summary.
    pub question: String,
    /// The answer extractable from the original conversation text.
    pub expected_answer: String,
}
30
/// Outcome of scoring a compaction probe against the configured thresholds.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ProbeVerdict {
    /// Average score reached the pass threshold.
    Pass,
    /// Score below the pass threshold but at or above the hard-fail threshold.
    SoftFail,
    /// Score below the hard-fail threshold.
    HardFail,
    /// The probe itself could not run.
    // NOTE(review): never constructed in this module — presumably set by
    // callers on provider failure; confirm against call sites.
    Error,
}
44
/// Full record of one compaction probe run, suitable for logging/persistence.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CompactionProbeResult {
    /// Average per-question score in `[0.0, 1.0]`.
    pub score: f32,
    /// Questions generated from the original conversation.
    pub questions: Vec<ProbeQuestion>,
    /// Answers the model produced from the summary alone.
    pub answers: Vec<String>,
    /// Individual score for each question (same order as `questions`).
    pub per_question_scores: Vec<f32>,
    /// Verdict derived from `score` and the thresholds below.
    pub verdict: ProbeVerdict,
    /// Pass threshold copied from the config used for this run.
    pub threshold: f32,
    /// Hard-fail threshold copied from the config used for this run.
    pub hard_fail_threshold: f32,
    /// Provider name that served the probe requests.
    pub model: String,
    /// Wall-clock duration of the probe in milliseconds.
    pub duration_ms: u64,
}
66
/// Typed JSON envelope for the question-generation LLM response.
#[derive(Debug, Deserialize, JsonSchema)]
struct ProbeQuestionsOutput {
    questions: Vec<ProbeQuestion>,
}
73
/// Typed JSON envelope for the question-answering LLM response.
#[derive(Debug, Deserialize, JsonSchema)]
struct ProbeAnswersOutput {
    answers: Vec<String>,
}
78
/// Lowercase substrings that indicate the answering model declined to
/// answer, i.e. the probed fact was absent from the summary.
const REFUSAL_PATTERNS: &[&str] = &[
    "unknown",
    "not mentioned",
    "not found",
    "n/a",
    "cannot determine",
    "no information",
    "not provided",
    "not specified",
    "not stated",
    "not available",
];

/// Returns `true` when `text` contains any refusal pattern
/// (case-insensitive substring match).
fn is_refusal(text: &str) -> bool {
    let haystack = text.to_lowercase();
    REFUSAL_PATTERNS
        .iter()
        .any(|pattern| haystack.contains(pattern))
}
99
/// Lowercases `text` and splits it on every non-alphanumeric character,
/// keeping only tokens of at least 3 characters (drops noise like "a",
/// "rs", "in").
fn normalize_tokens(text: &str) -> Vec<String> {
    let lowered = text.to_lowercase();
    lowered
        .split(|c: char| !c.is_alphanumeric())
        .filter(|token| token.len() >= 3)
        .map(ToOwned::to_owned)
        .collect()
}
108
/// Jaccard similarity of two token lists (duplicates collapse into sets).
/// Two empty lists are considered identical and score 1.0.
fn jaccard(a: &[String], b: &[String]) -> f32 {
    use std::collections::HashSet;

    if a.is_empty() && b.is_empty() {
        return 1.0;
    }
    let left: HashSet<&str> = a.iter().map(String::as_str).collect();
    let right: HashSet<&str> = b.iter().map(String::as_str).collect();
    let shared = left.intersection(&right).count();
    let total = left.union(&right).count();
    if total == 0 {
        return 0.0;
    }
    #[allow(clippy::cast_precision_loss)]
    {
        shared as f32 / total as f32
    }
}
125
126fn score_pair(expected: &str, actual: &str) -> f32 {
128 if is_refusal(actual) {
129 return 0.0;
130 }
131
132 let tokens_e = normalize_tokens(expected);
133 let tokens_a = normalize_tokens(actual);
134
135 if !tokens_e.is_empty() {
137 let set_e: std::collections::HashSet<&str> = tokens_e.iter().map(String::as_str).collect();
138 let set_a: std::collections::HashSet<&str> = tokens_a.iter().map(String::as_str).collect();
139 if set_e.is_subset(&set_a) {
140 return 1.0;
141 }
142 }
143
144 let j_full = jaccard(&tokens_e, &tokens_a);
146
147 let set_e: std::collections::HashSet<&str> = tokens_e.iter().map(String::as_str).collect();
149 let set_a: std::collections::HashSet<&str> = tokens_a.iter().map(String::as_str).collect();
150 let intersection: Vec<String> = set_e
151 .intersection(&set_a)
152 .map(|s| (*s).to_owned())
153 .collect();
154
155 #[allow(clippy::cast_precision_loss)]
156 let j_e = if tokens_e.is_empty() {
157 0.0_f32
158 } else {
159 intersection.len() as f32 / tokens_e.len() as f32
160 };
161 #[allow(clippy::cast_precision_loss)]
162 let j_a = if tokens_a.is_empty() {
163 0.0_f32
164 } else {
165 intersection.len() as f32 / tokens_a.len() as f32
166 };
167
168 j_full.max(j_e).max(j_a)
169}
170
171#[must_use]
175pub fn score_answers(questions: &[ProbeQuestion], answers: &[String]) -> (Vec<f32>, f32) {
176 if questions.is_empty() {
177 return (vec![], 0.0);
178 }
179 let scores: Vec<f32> = questions
180 .iter()
181 .zip(answers.iter().chain(std::iter::repeat(&String::new())))
182 .map(|(q, a)| score_pair(&q.expected_answer, a))
183 .collect();
184 #[allow(clippy::cast_precision_loss)]
185 let avg = if scores.is_empty() {
186 0.0
187 } else {
188 scores.iter().sum::<f32>() / scores.len() as f32
189 };
190 (scores, avg)
191}
192
193fn truncate_tool_bodies(messages: &[Message]) -> Vec<Message> {
197 messages
198 .iter()
199 .map(|m| {
200 let mut msg = m.clone();
201 for part in &mut msg.parts {
202 if let MessagePart::ToolOutput { body, .. } = part {
203 if body.len() <= 500 {
204 continue;
205 }
206 body.truncate(500);
207 body.push('\u{2026}');
208 }
209 }
210 msg.rebuild_content();
211 msg
212 })
213 .collect()
214}
215
216pub async fn generate_probe_questions(
225 provider: &AnyProvider,
226 messages: &[Message],
227 max_questions: usize,
228) -> Result<Vec<ProbeQuestion>, MemoryError> {
229 let truncated = truncate_tool_bodies(messages);
230
231 let mut history = String::new();
232 for msg in &truncated {
233 let role = match msg.role {
234 Role::User => "user",
235 Role::Assistant => "assistant",
236 Role::System => "system",
237 };
238 history.push_str(role);
239 history.push_str(": ");
240 history.push_str(&msg.content);
241 history.push('\n');
242 }
243
244 let prompt = format!(
245 "Given the following conversation excerpt, generate {max_questions} factual questions \
246 that test whether a summary preserves the most important concrete details.\n\
247 \n\
248 Focus on:\n\
249 - File paths, function names, struct/enum names that were modified or discussed\n\
250 - Architectural or implementation decisions with their rationale\n\
251 - Config values, API endpoints, error messages that were significant\n\
252 - Action items or next steps agreed upon\n\
253 \n\
254 Do NOT generate questions about:\n\
255 - Raw tool output content (compiler warnings, test output line numbers)\n\
256 - Intermediate debugging steps that were superseded\n\
257 - Opinions or reasoning that cannot be verified\n\
258 \n\
259 Each question must have a single unambiguous expected answer extractable from the text.\n\
260 \n\
261 Conversation:\n{history}\n\
262 \n\
263 Respond in JSON with schema: {{\"questions\": [{{\"question\": \"...\", \
264 \"expected_answer\": \"...\"}}]}}"
265 );
266
267 let msgs = [Message {
268 role: Role::User,
269 content: prompt,
270 parts: vec![],
271 metadata: MessageMetadata::default(),
272 }];
273
274 let mut output: ProbeQuestionsOutput = provider
275 .chat_typed_erased::<ProbeQuestionsOutput>(&msgs)
276 .await
277 .map_err(MemoryError::Llm)?;
278
279 output.questions.truncate(max_questions);
281
282 Ok(output.questions)
283}
284
285pub async fn answer_probe_questions(
291 provider: &AnyProvider,
292 summary: &str,
293 questions: &[ProbeQuestion],
294) -> Result<Vec<String>, MemoryError> {
295 let mut numbered = String::new();
296 for (i, q) in questions.iter().enumerate() {
297 use std::fmt::Write as _;
298 let _ = writeln!(numbered, "{}. {}", i + 1, q.question);
299 }
300
301 let prompt = format!(
302 "Given the following summary of a conversation, answer each question using ONLY \
303 information present in the summary. If the answer is not in the summary, respond \
304 with \"UNKNOWN\".\n\
305 \n\
306 Summary:\n{summary}\n\
307 \n\
308 Questions:\n{numbered}\n\
309 \n\
310 Respond in JSON with schema: {{\"answers\": [\"answer1\", \"answer2\", ...]}}"
311 );
312
313 let msgs = [Message {
314 role: Role::User,
315 content: prompt,
316 parts: vec![],
317 metadata: MessageMetadata::default(),
318 }];
319
320 let output: ProbeAnswersOutput = provider
321 .chat_typed_erased::<ProbeAnswersOutput>(&msgs)
322 .await
323 .map_err(MemoryError::Llm)?;
324
325 Ok(output.answers)
326}
327
/// Configuration for the compaction probe.
///
/// `#[serde(default)]` lets partially-specified configs fall back to
/// [`Default`] values field by field.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(default)]
pub struct CompactionProbeConfig {
    /// Master switch; when false the probe is skipped entirely.
    pub enabled: bool,
    /// Model identifier for probing.
    // NOTE(review): not read in this module (run_probe records the
    // provider's name instead) — confirm caller usage.
    pub model: String,
    /// Average score at or above which the probe passes (default 0.6).
    pub threshold: f32,
    /// Score below which the probe hard-fails (default 0.35).
    pub hard_fail_threshold: f32,
    /// Maximum number of questions to generate (default 3).
    pub max_questions: usize,
    /// Wall-clock budget for the whole probe (default 15 seconds).
    pub timeout_secs: u64,
}
349
impl Default for CompactionProbeConfig {
    /// Conservative defaults: probing disabled, 0.6 pass / 0.35 hard-fail
    /// thresholds, 3 questions, 15-second budget.
    fn default() -> Self {
        Self {
            enabled: false,
            model: String::new(),
            threshold: 0.6,
            hard_fail_threshold: 0.35,
            max_questions: 3,
            timeout_secs: 15,
        }
    }
}
362
363pub async fn validate_compaction(
377 provider: &AnyProvider,
378 messages: &[Message],
379 summary: &str,
380 config: &CompactionProbeConfig,
381) -> Result<Option<CompactionProbeResult>, MemoryError> {
382 if !config.enabled {
383 return Ok(None);
384 }
385
386 let timeout = std::time::Duration::from_secs(config.timeout_secs);
387 let start = Instant::now();
388
389 let result = tokio::time::timeout(timeout, async {
390 run_probe(provider, messages, summary, config).await
391 })
392 .await;
393
394 match result {
395 Ok(inner) => inner,
396 Err(_elapsed) => {
397 tracing::warn!(
398 timeout_secs = config.timeout_secs,
399 "compaction probe timed out — proceeding with compaction"
400 );
401 Ok(None)
402 }
403 }
404 .map(|opt| {
405 opt.map(|mut r| {
406 r.duration_ms = u64::try_from(start.elapsed().as_millis()).unwrap_or(u64::MAX);
407 r
408 })
409 })
410}
411
412async fn run_probe(
413 provider: &AnyProvider,
414 messages: &[Message],
415 summary: &str,
416 config: &CompactionProbeConfig,
417) -> Result<Option<CompactionProbeResult>, MemoryError> {
418 let questions = generate_probe_questions(provider, messages, config.max_questions).await?;
419
420 if questions.len() < 2 {
421 tracing::debug!(
422 count = questions.len(),
423 "compaction probe: fewer than 2 questions generated — skipping probe"
424 );
425 return Ok(None);
426 }
427
428 let answers = answer_probe_questions(provider, summary, &questions).await?;
429
430 let (per_question_scores, score) = score_answers(&questions, &answers);
431
432 let verdict = if score >= config.threshold {
433 ProbeVerdict::Pass
434 } else if score >= config.hard_fail_threshold {
435 ProbeVerdict::SoftFail
436 } else {
437 ProbeVerdict::HardFail
438 };
439
440 let model = provider.name().to_owned();
441
442 Ok(Some(CompactionProbeResult {
443 score,
444 questions,
445 answers,
446 per_question_scores,
447 verdict,
448 threshold: config.threshold,
449 hard_fail_threshold: config.hard_fail_threshold,
450 model,
451 duration_ms: 0, }))
453}
454
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a single-question probe set; most scoring tests need one.
    fn single(question: &str, expected: &str) -> Vec<ProbeQuestion> {
        vec![ProbeQuestion {
            question: question.into(),
            expected_answer: expected.into(),
        }]
    }

    /// Mirrors the verdict classification used by `run_probe`; the
    /// previous tests copy-pasted this if/else chain seven times.
    fn classify(score: f32, config: &CompactionProbeConfig) -> ProbeVerdict {
        if score >= config.threshold {
            ProbeVerdict::Pass
        } else if score >= config.hard_fail_threshold {
            ProbeVerdict::SoftFail
        } else {
            ProbeVerdict::HardFail
        }
    }

    #[test]
    fn score_perfect_match() {
        let q = single("What crate is used?", "thiserror");
        let a = vec!["thiserror".into()];
        let (scores, avg) = score_answers(&q, &a);
        assert_eq!(scores.len(), 1);
        assert!((avg - 1.0).abs() < 0.01, "expected ~1.0, got {avg}");
    }

    #[test]
    fn score_complete_mismatch() {
        let q = single("What file was modified?", "src/auth.rs");
        let a = vec!["definitely not in the summary".into()];
        let (scores, avg) = score_answers(&q, &a);
        assert_eq!(scores.len(), 1);
        assert!(avg < 0.5, "expected low score, got {avg}");
    }

    #[test]
    fn score_refusal_is_zero() {
        let q = single("What was the decision?", "Use thiserror for typed errors");
        for refusal in &[
            "UNKNOWN",
            "not mentioned",
            "N/A",
            "cannot determine",
            "No information",
        ] {
            let a = vec![(*refusal).to_owned()];
            let (_, avg) = score_answers(&q, &a);
            assert!(avg < 0.01, "expected 0 for refusal '{refusal}', got {avg}");
        }
    }

    #[test]
    fn score_paraphrased_answer_above_half() {
        let q = single(
            "What error handling crate was chosen?",
            "Use thiserror for typed errors in library crates",
        );
        let a = vec!["thiserror was chosen for error types in library crates".into()];
        let (_, avg) = score_answers(&q, &a);
        assert!(avg > 0.5, "expected >0.5 for paraphrase, got {avg}");
    }

    #[test]
    fn score_empty_strings() {
        let q = single("What?", "");
        let a = vec![String::new()];
        let (scores, avg) = score_answers(&q, &a);
        assert_eq!(scores.len(), 1);
        assert!(
            (avg - 1.0).abs() < 0.01,
            "expected 1.0 for empty vs empty, got {avg}"
        );
    }

    #[test]
    fn score_empty_questions_list() {
        let (scores, avg) = score_answers(&[], &[]);
        assert!(scores.is_empty());
        assert!((avg - 0.0).abs() < 0.01);
    }

    #[test]
    fn score_file_path_exact() {
        let q = single(
            "Which file was modified?",
            "crates/zeph-memory/src/compaction_probe.rs",
        );
        let a = vec!["The file crates/zeph-memory/src/compaction_probe.rs was modified.".into()];
        let (_, avg) = score_answers(&q, &a);
        assert!(
            avg > 0.8,
            "expected high score for file path match, got {avg}"
        );
    }

    #[test]
    fn score_unicode_input() {
        // Non-ASCII (Cyrillic) input must tokenize without panicking.
        let q = single("Что было изменено?", "файл config.toml");
        let a = vec!["config.toml был изменён".into()];
        let (scores, _) = score_answers(&q, &a);
        assert_eq!(scores.len(), 1);
    }

    #[test]
    fn verdict_thresholds() {
        let config = CompactionProbeConfig::default();
        assert_eq!(classify(0.7, &config), ProbeVerdict::Pass);
        assert_eq!(classify(0.5, &config), ProbeVerdict::SoftFail);
        assert_eq!(classify(0.2, &config), ProbeVerdict::HardFail);
    }

    #[test]
    fn config_defaults() {
        let c = CompactionProbeConfig::default();
        assert!(!c.enabled);
        assert!(c.model.is_empty());
        assert!((c.threshold - 0.6).abs() < 0.001);
        assert!((c.hard_fail_threshold - 0.35).abs() < 0.001);
        assert_eq!(c.max_questions, 3);
        assert_eq!(c.timeout_secs, 15);
    }

    #[test]
    fn config_serde_round_trip() {
        let original = CompactionProbeConfig {
            enabled: true,
            model: "claude-haiku-4-5-20251001".into(),
            threshold: 0.65,
            hard_fail_threshold: 0.4,
            max_questions: 5,
            timeout_secs: 20,
        };
        let json = serde_json::to_string(&original).expect("serialize");
        let restored: CompactionProbeConfig = serde_json::from_str(&json).expect("deserialize");
        assert!(restored.enabled);
        assert_eq!(restored.model, "claude-haiku-4-5-20251001");
        assert!((restored.threshold - 0.65).abs() < 0.001);
    }

    #[test]
    fn probe_result_serde_round_trip() {
        let result = CompactionProbeResult {
            score: 0.75,
            questions: single("What?", "thiserror"),
            answers: vec!["thiserror".into()],
            per_question_scores: vec![1.0],
            verdict: ProbeVerdict::Pass,
            threshold: 0.6,
            hard_fail_threshold: 0.35,
            model: "haiku".into(),
            duration_ms: 1234,
        };
        let json = serde_json::to_string(&result).expect("serialize");
        let restored: CompactionProbeResult = serde_json::from_str(&json).expect("deserialize");
        assert!((restored.score - 0.75).abs() < 0.001);
        assert_eq!(restored.verdict, ProbeVerdict::Pass);
    }

    #[test]
    fn score_fewer_answers_than_questions() {
        let questions = vec![
            ProbeQuestion {
                question: "What crate?".into(),
                expected_answer: "thiserror".into(),
            },
            ProbeQuestion {
                question: "What file?".into(),
                expected_answer: "src/lib.rs".into(),
            },
            ProbeQuestion {
                question: "What decision?".into(),
                expected_answer: "use async traits".into(),
            },
        ];
        // Only one answer supplied: the other two score against "".
        let answers = vec!["thiserror".into()];
        let (scores, avg) = score_answers(&questions, &answers);
        assert_eq!(scores.len(), 3);
        assert!(
            (scores[0] - 1.0).abs() < 0.01,
            "first score should be ~1.0, got {}",
            scores[0]
        );
        assert!(
            scores[1] < 0.5,
            "second score should be low for missing answer, got {}",
            scores[1]
        );
        assert!(
            scores[2] < 0.5,
            "third score should be low for missing answer, got {}",
            scores[2]
        );
        assert!(
            avg < 0.5,
            "average should be below 0.5 with 2 missing answers, got {avg}"
        );
    }

    #[test]
    fn verdict_boundary_at_threshold() {
        let config = CompactionProbeConfig::default();
        // Exactly at a threshold counts as meeting it (>= comparisons).
        assert_eq!(classify(config.threshold, &config), ProbeVerdict::Pass);
        assert_eq!(
            classify(config.threshold - f32::EPSILON, &config),
            ProbeVerdict::SoftFail
        );
        assert_eq!(
            classify(config.hard_fail_threshold, &config),
            ProbeVerdict::SoftFail
        );
        assert_eq!(
            classify(config.hard_fail_threshold - f32::EPSILON, &config),
            ProbeVerdict::HardFail
        );
    }

    #[test]
    fn config_partial_json_uses_defaults() {
        let json = r#"{"enabled": true}"#;
        let c: CompactionProbeConfig =
            serde_json::from_str(json).expect("deserialize partial json");
        assert!(c.enabled);
        assert!(c.model.is_empty());
        assert!((c.threshold - 0.6).abs() < 0.001);
        assert!((c.hard_fail_threshold - 0.35).abs() < 0.001);
        assert_eq!(c.max_questions, 3);
        assert_eq!(c.timeout_secs, 15);
    }

    #[test]
    fn config_empty_json_uses_all_defaults() {
        let c: CompactionProbeConfig = serde_json::from_str("{}").expect("deserialize empty json");
        assert!(!c.enabled);
        assert!(c.model.is_empty());
    }
}
782}