1use std::collections::HashMap;
12use std::time::Instant;
13
14use schemars::JsonSchema;
15use serde::{Deserialize, Serialize};
16use zeph_llm::any::AnyProvider;
17use zeph_llm::provider::{LlmProvider as _, Message, MessageMetadata, MessagePart, Role};
18
19use crate::error::MemoryError;
20
/// Category of a compaction probe question.
///
/// Serialized in lowercase (`"recall"`, `"artifact"`, `"continuation"`, `"decision"`), which
/// matches the category names the question-generation prompt asks the LLM to emit.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "lowercase")]
pub enum ProbeCategory {
    /// Specific facts that survived compaction (file paths, function names, values).
    Recall,
    /// Which files, tools, or URLs the agent used.
    Artifact,
    /// Next steps, blockers, and open questions.
    Continuation,
    /// Past reasoning traces (why X was chosen over Y, trade-offs).
    Decision,
}
36
/// Aggregated probe score for a single [`ProbeCategory`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CategoryScore {
    /// Category the contributing probe questions belong to.
    pub category: ProbeCategory,
    /// Mean of the per-question scores of this category.
    pub score: f32,
    /// Number of probe questions that contributed to `score`.
    pub probes_run: u32,
}
46
/// A single probe question generated from the original (pre-compaction) conversation.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct ProbeQuestion {
    /// Question posed to the model that only sees the summary.
    pub question: String,
    /// Ground-truth answer extracted from the original conversation; answers are scored
    /// against this text.
    pub expected_answer: String,
    /// Question category. When the field is missing from the JSON it defaults to
    /// [`ProbeCategory::Recall`].
    #[serde(default = "default_probe_category")]
    pub category: ProbeCategory,
}
58
59fn default_probe_category() -> ProbeCategory {
60 ProbeCategory::Recall
61}
62
63impl Default for ProbeQuestion {
64 fn default() -> Self {
65 Self {
66 question: String::new(),
67 expected_answer: String::new(),
68 category: ProbeCategory::Recall,
69 }
70 }
71}
72
/// Outcome of a compaction probe, derived from the overall probe score.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ProbeVerdict {
    /// Score is at or above the pass `threshold`.
    Pass,
    /// Score is below `threshold` but at or above `hard_fail_threshold`.
    SoftFail,
    /// Score is below `hard_fail_threshold`.
    HardFail,
    /// NOTE(review): never constructed in this module — presumably reserved for callers that
    /// record a probe that could not be evaluated; confirm against call sites.
    Error,
}
86
/// Full record of one compaction-probe run.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CompactionProbeResult {
    /// Overall score: the (optionally category-weighted) mean of the category scores.
    pub score: f32,
    /// Per-category breakdown of `score`. Defaults to empty when absent so that records
    /// serialized without this field still deserialize.
    #[serde(default)]
    pub category_scores: Vec<CategoryScore>,
    /// Questions generated from the original conversation.
    pub questions: Vec<ProbeQuestion>,
    /// Answers produced from the summary alone, matched to `questions` by position.
    pub answers: Vec<String>,
    /// Score of each answer against its question's expected answer, in question order.
    pub per_question_scores: Vec<f32>,
    /// Verdict derived from `score` and the two thresholds below.
    pub verdict: ProbeVerdict,
    /// Pass threshold that was in effect for this run.
    pub threshold: f32,
    /// Hard-fail threshold that was in effect for this run.
    pub hard_fail_threshold: f32,
    /// Name of the provider that ran the probe.
    pub model: String,
    /// Wall-clock duration of the probe, in milliseconds.
    pub duration_ms: u64,
}
111
/// Structured LLM response for question generation: `{"questions": [...]}`.
#[derive(Debug, Deserialize, JsonSchema)]
struct ProbeQuestionsOutput {
    questions: Vec<ProbeQuestion>,
}
118
119#[must_use]
125fn compute_category_scores(
126 questions: &[ProbeQuestion],
127 per_question_scores: &[f32],
128 category_weights: Option<&HashMap<ProbeCategory, f32>>,
129) -> (Vec<CategoryScore>, f32) {
130 let mut by_cat: HashMap<ProbeCategory, Vec<f32>> = HashMap::new();
132 for (q, &s) in questions.iter().zip(per_question_scores.iter()) {
133 by_cat.entry(q.category).or_default().push(s);
134 }
135
136 #[allow(clippy::cast_precision_loss)]
137 let category_scores: Vec<CategoryScore> = by_cat
138 .into_iter()
139 .map(|(category, scores)| {
140 let avg = scores.iter().sum::<f32>() / scores.len() as f32;
141 CategoryScore {
142 category,
143 score: avg,
144 #[allow(clippy::cast_possible_truncation)]
145 probes_run: scores.len() as u32,
146 }
147 })
148 .collect();
149
150 if category_scores.is_empty() {
151 return (category_scores, 0.0);
152 }
153
154 let mut weighted_sum = 0.0_f32;
156 let mut weight_total = 0.0_f32;
157 for cs in &category_scores {
158 let raw_w = category_weights
159 .and_then(|m| m.get(&cs.category).copied())
160 .unwrap_or(1.0);
161 if raw_w < 0.0 {
162 tracing::warn!(
163 category = ?cs.category,
164 weight = raw_w,
165 "category_weights contains a negative value — treating as 0.0 (category excluded from scoring)"
166 );
167 }
168 let w = raw_w.max(0.0);
169 weighted_sum += cs.score * w;
170 weight_total += w;
171 }
172
173 let overall = if weight_total > 0.0 {
174 weighted_sum / weight_total
175 } else {
176 #[allow(clippy::cast_precision_loss)]
178 let n = category_scores.len() as f32;
179 category_scores.iter().map(|cs| cs.score).sum::<f32>() / n
180 };
181
182 (category_scores, overall)
183}
184
/// Structured LLM response for answering probe questions: `{"answers": [...]}`.
#[derive(Debug, Deserialize, JsonSchema)]
struct ProbeAnswersOutput {
    answers: Vec<String>,
}
189
/// Lowercase phrases that mark an answer as a refusal ("the summary doesn't say").
///
/// Matched as substrings of the lowercased answer; see [`is_refusal_for`] for how a match is
/// reconciled with the expected answer.
const REFUSAL_PATTERNS: &[&str] = &[
    "unknown",
    "not mentioned",
    "not found",
    "n/a",
    "cannot determine",
    "no information",
    "not provided",
    "not specified",
    "not stated",
    "not available",
];

/// Returns `true` if `text` contains any [`REFUSAL_PATTERNS`] phrase (case-insensitive).
fn is_refusal(text: &str) -> bool {
    let lower = text.to_lowercase();
    REFUSAL_PATTERNS.iter().any(|p| lower.contains(p))
}

/// Returns `true` if `actual` contains a refusal phrase that `expected` does not contain.
///
/// Phrases such as "not found" or "unknown" can be part of a legitimate answer (e.g. the
/// expected answer "file not found"). A phrase shared with the expected answer is therefore
/// not treated as a refusal.
fn is_refusal_for(expected: &str, actual: &str) -> bool {
    // Cheap gate: most answers contain no refusal phrase at all.
    if !is_refusal(actual) {
        return false;
    }
    let expected = expected.to_lowercase();
    let actual = actual.to_lowercase();
    REFUSAL_PATTERNS
        .iter()
        .any(|p| actual.contains(p) && !expected.contains(p))
}

/// Lowercases `text`, splits it on non-alphanumeric characters, and keeps tokens of at least
/// three characters.
///
/// Length is measured in `char`s, not bytes. With bytes, short non-ASCII words (e.g. the
/// Cyrillic "на", 4 bytes) would pass the filter while their ASCII counterparts are dropped.
fn normalize_tokens(text: &str) -> Vec<String> {
    text.to_lowercase()
        .split(|c: char| !c.is_alphanumeric())
        .filter(|t| t.chars().count() >= 3)
        .map(String::from)
        .collect()
}

/// Jaccard similarity `|A ∩ B| / |A ∪ B|` of the distinct tokens in `a` and `b`.
///
/// Two empty token lists are considered identical and score `1.0`.
fn jaccard(a: &[String], b: &[String]) -> f32 {
    if a.is_empty() && b.is_empty() {
        return 1.0;
    }
    let set_a: std::collections::HashSet<&str> = a.iter().map(String::as_str).collect();
    let set_b: std::collections::HashSet<&str> = b.iter().map(String::as_str).collect();
    let intersection = set_a.intersection(&set_b).count();
    // At least one input is non-empty here, so the union is never empty.
    let union = set_a.union(&set_b).count();
    #[allow(clippy::cast_precision_loss)] // token counts are small
    {
        intersection as f32 / union as f32
    }
}

/// Scores how well `actual` answers a question whose ground truth is `expected`, in `[0.0, 1.0]`.
///
/// - `0.0` if `actual` contains a refusal phrase that `expected` does not (see [`is_refusal_for`]).
/// - Otherwise, the maximum of three values:
///   - the Jaccard similarity of the two token sets;
///   - the fraction of distinct expected tokens found in the answer;
///   - the fraction of distinct answer tokens found in the expected answer.
///
///   An answer that contains every expected token therefore scores `1.0`.
/// - If neither side has a qualifying token (e.g. both are empty), the score is `1.0`.
fn score_pair(expected: &str, actual: &str) -> f32 {
    if is_refusal_for(expected, actual) {
        return 0.0;
    }

    let tokens_e = normalize_tokens(expected);
    let tokens_a = normalize_tokens(actual);
    let j_full = jaccard(&tokens_e, &tokens_a);

    // Every ratio is computed over *distinct* tokens so it agrees with `jaccard`. Dividing a
    // set intersection by a token count that includes duplicates under-scores any answer
    // in which a word repeats.
    let set_e: std::collections::HashSet<&str> = tokens_e.iter().map(String::as_str).collect();
    let set_a: std::collections::HashSet<&str> = tokens_a.iter().map(String::as_str).collect();
    let shared = set_e.intersection(&set_a).count();

    #[allow(clippy::cast_precision_loss)] // token counts are small
    let ratio = |part: usize, whole: usize| {
        if whole == 0 {
            0.0_f32
        } else {
            part as f32 / whole as f32
        }
    };
    // `j_e == 1.0` exactly when every expected token appears in the answer.
    let j_e = ratio(shared, set_e.len());
    let j_a = ratio(shared, set_a.len());

    j_full.max(j_e).max(j_a)
}
281
282#[must_use]
286pub fn score_answers(questions: &[ProbeQuestion], answers: &[String]) -> (Vec<f32>, f32) {
287 if questions.is_empty() {
288 return (vec![], 0.0);
289 }
290 let scores: Vec<f32> = questions
291 .iter()
292 .zip(answers.iter().chain(std::iter::repeat(&String::new())))
293 .map(|(q, a)| score_pair(&q.expected_answer, a))
294 .collect();
295 #[allow(clippy::cast_precision_loss)]
296 let avg = if scores.is_empty() {
297 0.0
298 } else {
299 scores.iter().sum::<f32>() / scores.len() as f32
300 };
301 (scores, avg)
302}
303
304fn truncate_tool_bodies(messages: &[Message]) -> Vec<Message> {
308 messages
309 .iter()
310 .map(|m| {
311 let mut msg = m.clone();
312 for part in &mut msg.parts {
313 if let MessagePart::ToolOutput { body, .. } = part {
314 if body.len() <= 500 {
315 continue;
316 }
317 body.truncate(500);
318 body.push('\u{2026}');
319 }
320 }
321 msg.rebuild_content();
322 msg
323 })
324 .collect()
325}
326
327#[cfg_attr(
336 feature = "profiling",
337 tracing::instrument(name = "memory.compaction_probe", skip_all)
338)]
339pub async fn generate_probe_questions(
340 provider: &AnyProvider,
341 messages: &[Message],
342 max_questions: usize,
343) -> Result<Vec<ProbeQuestion>, MemoryError> {
344 let truncated = truncate_tool_bodies(messages);
345
346 let mut history = String::new();
347 for msg in &truncated {
348 let role = match msg.role {
349 Role::User => "user",
350 Role::Assistant => "assistant",
351 Role::System => "system",
352 };
353 history.push_str(role);
354 history.push_str(": ");
355 history.push_str(&msg.content);
356 history.push('\n');
357 }
358
359 let prompt = format!(
360 "Given the following conversation excerpt, generate {max_questions} factual questions \
361 that test whether a summary preserves the most important concrete details.\n\
362 \n\
363 You MUST generate at least one question per category when max_questions >= 4. \
364 If the conversation lacks information for a category, generate a question noting that absence.\n\
365 \n\
366 Categories:\n\
367 - recall: Specific facts that survived (file paths, function names, values). \
368 Example: \"What file was modified?\"\n\
369 - artifact: Which files/tools/URLs the agent used. \
370 Example: \"Which tool was executed?\"\n\
371 - continuation: Next steps, blockers, open questions. \
372 Example: \"What is the next step?\"\n\
373 - decision: Past reasoning traces (why X over Y, trade-offs). \
374 Example: \"Why was X chosen over Y?\"\n\
375 \n\
376 Do NOT generate questions about:\n\
377 - Raw tool output content (compiler warnings, test output line numbers)\n\
378 - Intermediate debugging steps that were superseded\n\
379 - Opinions or reasoning that cannot be verified\n\
380 \n\
381 Each question must have a single unambiguous expected answer extractable from the text.\n\
382 \n\
383 Conversation:\n{history}\n\
384 \n\
385 Respond in JSON with schema: {{\"questions\": [{{\"question\": \"...\", \
386 \"expected_answer\": \"...\", \"category\": \"recall|artifact|continuation|decision\"}}]}}"
387 );
388
389 let msgs = [Message {
390 role: Role::User,
391 content: prompt,
392 parts: vec![],
393 metadata: MessageMetadata::default(),
394 }];
395
396 let mut output: ProbeQuestionsOutput = provider
397 .chat_typed_erased::<ProbeQuestionsOutput>(&msgs)
398 .await
399 .map_err(MemoryError::Llm)?;
400
401 output.questions.truncate(max_questions);
403
404 Ok(output.questions)
405}
406
407pub async fn answer_probe_questions(
413 provider: &AnyProvider,
414 summary: &str,
415 questions: &[ProbeQuestion],
416) -> Result<Vec<String>, MemoryError> {
417 let mut numbered = String::new();
418 for (i, q) in questions.iter().enumerate() {
419 use std::fmt::Write as _;
420 let _ = writeln!(numbered, "{}. {}", i + 1, q.question);
421 }
422
423 let prompt = format!(
424 "Given the following summary of a conversation, answer each question using ONLY \
425 information present in the summary. If the answer is not in the summary, respond \
426 with \"UNKNOWN\".\n\
427 \n\
428 Summary:\n{summary}\n\
429 \n\
430 Questions:\n{numbered}\n\
431 \n\
432 Respond in JSON with schema: {{\"answers\": [\"answer1\", \"answer2\", ...]}}"
433 );
434
435 let msgs = [Message {
436 role: Role::User,
437 content: prompt,
438 parts: vec![],
439 metadata: MessageMetadata::default(),
440 }];
441
442 let output: ProbeAnswersOutput = provider
443 .chat_typed_erased::<ProbeAnswersOutput>(&msgs)
444 .await
445 .map_err(MemoryError::Llm)?;
446
447 Ok(output.answers)
448}
449
/// Configuration for validating compaction summaries with probe questions.
///
/// Every field falls back to its `Default` value when missing from the serialized config.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(default)]
pub struct CompactionProbeConfig {
    /// Whether probing runs at all. When `false`, `validate_compaction` returns `Ok(None)`.
    pub enabled: bool,
    /// Name of the provider to probe with. NOTE(review): not read in this module; presumably
    /// the caller resolves it into the `AnyProvider` passed to `validate_compaction`. Confirm.
    pub probe_provider: String,
    /// Inclusive lower bound of the overall score for [`ProbeVerdict::Pass`].
    pub threshold: f32,
    /// Inclusive lower bound for [`ProbeVerdict::SoftFail`]. Lower scores are
    /// [`ProbeVerdict::HardFail`].
    pub hard_fail_threshold: f32,
    /// Maximum number of probe questions. At 4 or more, the prompt requests every category,
    /// and categories the LLM skipped are logged.
    pub max_questions: usize,
    /// Wall-clock budget for the whole probe, in seconds. On timeout the probe result is
    /// dropped and compaction proceeds.
    pub timeout_secs: u64,
    /// Optional per-category weights for the overall score. Categories missing from the map
    /// weigh `1.0`, and negative values are treated as `0.0`.
    #[serde(default)]
    pub category_weights: Option<HashMap<ProbeCategory, f32>>,
}
474
475impl Default for CompactionProbeConfig {
476 fn default() -> Self {
477 Self {
478 enabled: false,
479 probe_provider: String::new(),
480 threshold: 0.6,
481 hard_fail_threshold: 0.35,
482 max_questions: 5,
483 timeout_secs: 15,
484 category_weights: None,
485 }
486 }
487}
488
489pub async fn validate_compaction(
503 provider: AnyProvider,
504 messages: Vec<Message>,
505 summary: String,
506 config: &CompactionProbeConfig,
507) -> Result<Option<CompactionProbeResult>, MemoryError> {
508 if !config.enabled {
509 return Ok(None);
510 }
511
512 let timeout = std::time::Duration::from_secs(config.timeout_secs);
513 let start = Instant::now();
514
515 let result = tokio::time::timeout(timeout, async {
516 run_probe(provider, messages, summary, config).await
517 })
518 .await;
519
520 match result {
521 Ok(inner) => inner,
522 Err(_elapsed) => {
523 tracing::warn!(
524 timeout_secs = config.timeout_secs,
525 "compaction probe timed out — proceeding with compaction"
526 );
527 Ok(None)
528 }
529 }
530 .map(|opt| {
531 opt.map(|mut r| {
532 r.duration_ms = u64::try_from(start.elapsed().as_millis()).unwrap_or(u64::MAX);
533 r
534 })
535 })
536}
537
538async fn run_probe(
539 provider: AnyProvider,
540 messages: Vec<Message>,
541 summary: String,
542 config: &CompactionProbeConfig,
543) -> Result<Option<CompactionProbeResult>, MemoryError> {
544 if summary.len() < 10 {
545 tracing::warn!(
546 len = summary.len(),
547 "compaction probe: summary too short — skipping probe"
548 );
549 return Ok(None);
550 }
551
552 let questions = generate_probe_questions(&provider, &messages, config.max_questions).await?;
553
554 if questions.len() < 2 {
555 tracing::debug!(
556 count = questions.len(),
557 "compaction probe: fewer than 2 questions generated — skipping probe"
558 );
559 return Ok(None);
560 }
561
562 if config.max_questions >= 4 {
564 use std::collections::HashSet;
565 let covered: HashSet<_> = questions.iter().map(|q| q.category).collect();
566 for cat in [
567 ProbeCategory::Recall,
568 ProbeCategory::Artifact,
569 ProbeCategory::Continuation,
570 ProbeCategory::Decision,
571 ] {
572 if !covered.contains(&cat) {
573 tracing::warn!(
574 category = ?cat,
575 "compaction probe: LLM did not generate questions for category"
576 );
577 }
578 }
579 }
580
581 let answers = answer_probe_questions(&provider, &summary, &questions).await?;
582
583 let (per_question_scores, _simple_avg) = score_answers(&questions, &answers);
584
585 let (category_scores, score) = compute_category_scores(
586 &questions,
587 &per_question_scores,
588 config.category_weights.as_ref(),
589 );
590
591 let verdict = if score >= config.threshold {
592 ProbeVerdict::Pass
593 } else if score >= config.hard_fail_threshold {
594 ProbeVerdict::SoftFail
595 } else {
596 ProbeVerdict::HardFail
597 };
598
599 let model = provider.name().to_owned();
600
601 Ok(Some(CompactionProbeResult {
602 score,
603 category_scores,
604 questions,
605 answers,
606 per_question_scores,
607 verdict,
608 threshold: config.threshold,
609 hard_fail_threshold: config.hard_fail_threshold,
610 model,
611 duration_ms: 0, }))
613}
614
// Unit tests for answer scoring, category aggregation, verdict thresholds, and the serde
// shape of the probe config/result types. None of these tests contact an LLM.
#[cfg(test)]
mod tests {
    use super::*;

    // ---- score_answers: token-overlap scoring ---------------------------------------------

    #[test]
    fn score_perfect_match() {
        let q = vec![ProbeQuestion {
            question: "What crate is used?".into(),
            expected_answer: "thiserror".into(),
            category: ProbeCategory::Recall,
        }];
        let a = vec!["thiserror".into()];
        let (scores, avg) = score_answers(&q, &a);
        assert_eq!(scores.len(), 1);
        assert!((avg - 1.0).abs() < 0.01, "expected ~1.0, got {avg}");
    }

    #[test]
    fn score_complete_mismatch() {
        let q = vec![ProbeQuestion {
            question: "What file was modified?".into(),
            expected_answer: "src/auth.rs".into(),
            ..Default::default()
        }];
        let a = vec!["definitely not in the summary".into()];
        let (scores, avg) = score_answers(&q, &a);
        assert_eq!(scores.len(), 1);
        assert!(avg < 0.5, "expected low score, got {avg}");
    }

    // Each phrase below is matched case-insensitively against the refusal pattern list.
    #[test]
    fn score_refusal_is_zero() {
        let q = vec![ProbeQuestion {
            question: "What was the decision?".into(),
            expected_answer: "Use thiserror for typed errors".into(),
            ..Default::default()
        }];
        for refusal in &[
            "UNKNOWN",
            "not mentioned",
            "N/A",
            "cannot determine",
            "No information",
        ] {
            let a = vec![(*refusal).to_owned()];
            let (_, avg) = score_answers(&q, &a);
            assert!(avg < 0.01, "expected 0 for refusal '{refusal}', got {avg}");
        }
    }

    #[test]
    fn score_paraphrased_answer_above_half() {
        let q = vec![ProbeQuestion {
            question: "What error handling crate was chosen?".into(),
            expected_answer: "Use thiserror for typed errors in library crates".into(),
            ..Default::default()
        }];
        let a = vec!["thiserror was chosen for error types in library crates".into()];
        let (_, avg) = score_answers(&q, &a);
        assert!(avg > 0.5, "expected >0.5 for paraphrase, got {avg}");
    }

    // Both sides have no tokens, which is treated as a perfect match.
    #[test]
    fn score_empty_strings() {
        let q = vec![ProbeQuestion {
            question: "What?".into(),
            expected_answer: String::new(),
            ..Default::default()
        }];
        let a = vec![String::new()];
        let (scores, avg) = score_answers(&q, &a);
        assert_eq!(scores.len(), 1);
        assert!(
            (avg - 1.0).abs() < 0.01,
            "expected 1.0 for empty vs empty, got {avg}"
        );
    }

    #[test]
    fn score_empty_questions_list() {
        let (scores, avg) = score_answers(&[], &[]);
        assert!(scores.is_empty());
        assert!((avg - 0.0).abs() < 0.01);
    }

    #[test]
    fn score_file_path_exact() {
        let q = vec![ProbeQuestion {
            question: "Which file was modified?".into(),
            expected_answer: "crates/zeph-memory/src/compaction_probe.rs".into(),
            ..Default::default()
        }];
        let a = vec!["The file crates/zeph-memory/src/compaction_probe.rs was modified.".into()];
        let (_, avg) = score_answers(&q, &a);
        assert!(
            avg > 0.8,
            "expected high score for file path match, got {avg}"
        );
    }

    // Smoke test: non-ASCII input must not panic; the exact score is not asserted.
    #[test]
    fn score_unicode_input() {
        let q = vec![ProbeQuestion {
            question: "Что было изменено?".into(),
            expected_answer: "файл config.toml".into(),
            ..Default::default()
        }];
        let a = vec!["config.toml был изменён".into()];
        let (scores, _) = score_answers(&q, &a);
        assert_eq!(scores.len(), 1);
    }

    // ---- verdict ladder ---------------------------------------------------------------------
    // These tests restate the threshold comparison used by `run_probe` inline, checking it
    // against the default config values.

    #[test]
    fn verdict_thresholds() {
        let config = CompactionProbeConfig::default();

        let score = 0.7_f32;
        let verdict = if score >= config.threshold {
            ProbeVerdict::Pass
        } else if score >= config.hard_fail_threshold {
            ProbeVerdict::SoftFail
        } else {
            ProbeVerdict::HardFail
        };
        assert_eq!(verdict, ProbeVerdict::Pass);

        let score = 0.5_f32;
        let verdict = if score >= config.threshold {
            ProbeVerdict::Pass
        } else if score >= config.hard_fail_threshold {
            ProbeVerdict::SoftFail
        } else {
            ProbeVerdict::HardFail
        };
        assert_eq!(verdict, ProbeVerdict::SoftFail);

        let score = 0.2_f32;
        let verdict = if score >= config.threshold {
            ProbeVerdict::Pass
        } else if score >= config.hard_fail_threshold {
            ProbeVerdict::SoftFail
        } else {
            ProbeVerdict::HardFail
        };
        assert_eq!(verdict, ProbeVerdict::HardFail);
    }

    // ---- config & result serde --------------------------------------------------------------

    #[test]
    fn config_defaults() {
        let c = CompactionProbeConfig::default();
        assert!(!c.enabled);
        assert!(c.probe_provider.is_empty());
        assert!((c.threshold - 0.6).abs() < 0.001);
        assert!((c.hard_fail_threshold - 0.35).abs() < 0.001);
        assert_eq!(c.max_questions, 5);
        assert_eq!(c.timeout_secs, 15);
        assert!(c.category_weights.is_none());
    }

    #[test]
    fn config_serde_round_trip() {
        let original = CompactionProbeConfig {
            enabled: true,
            probe_provider: "fast".into(),
            threshold: 0.65,
            hard_fail_threshold: 0.4,
            max_questions: 5,
            timeout_secs: 20,
            category_weights: None,
        };
        let json = serde_json::to_string(&original).expect("serialize");
        let restored: CompactionProbeConfig = serde_json::from_str(&json).expect("deserialize");
        assert!(restored.enabled);
        assert_eq!(restored.probe_provider, "fast");
        assert!((restored.threshold - 0.65).abs() < 0.001);
    }

    #[test]
    fn probe_result_serde_round_trip() {
        let result = CompactionProbeResult {
            score: 0.75,
            category_scores: vec![CategoryScore {
                category: ProbeCategory::Recall,
                score: 0.75,
                probes_run: 1,
            }],
            questions: vec![ProbeQuestion {
                question: "What?".into(),
                expected_answer: "thiserror".into(),
                category: ProbeCategory::Recall,
            }],
            answers: vec!["thiserror".into()],
            per_question_scores: vec![1.0],
            verdict: ProbeVerdict::Pass,
            threshold: 0.6,
            hard_fail_threshold: 0.35,
            model: "haiku".into(),
            duration_ms: 1234,
        };
        let json = serde_json::to_string(&result).expect("serialize");
        let restored: CompactionProbeResult = serde_json::from_str(&json).expect("deserialize");
        assert!((restored.score - 0.75).abs() < 0.001);
        assert_eq!(restored.verdict, ProbeVerdict::Pass);
        assert_eq!(restored.category_scores.len(), 1);
    }

    // JSON without `category_scores` must still deserialize (the field has `#[serde(default)]`).
    #[test]
    fn probe_result_backward_compat_no_category_scores() {
        let json = r#"{"score":0.75,"questions":[],"answers":[],"per_question_scores":[],"verdict":"Pass","threshold":0.6,"hard_fail_threshold":0.35,"model":"haiku","duration_ms":0}"#;
        let restored: CompactionProbeResult = serde_json::from_str(json).expect("deserialize");
        assert!(restored.category_scores.is_empty());
    }

    // Missing answers (the LLM returned fewer than asked) are scored as empty strings.
    #[test]
    fn score_fewer_answers_than_questions() {
        let questions = vec![
            ProbeQuestion {
                question: "What crate?".into(),
                expected_answer: "thiserror".into(),
                ..Default::default()
            },
            ProbeQuestion {
                question: "What file?".into(),
                expected_answer: "src/lib.rs".into(),
                ..Default::default()
            },
            ProbeQuestion {
                question: "What decision?".into(),
                expected_answer: "use async traits".into(),
                ..Default::default()
            },
        ];
        let answers = vec!["thiserror".into()];
        let (scores, avg) = score_answers(&questions, &answers);
        assert_eq!(scores.len(), 3);
        assert!(
            (scores[0] - 1.0).abs() < 0.01,
            "first score should be ~1.0, got {}",
            scores[0]
        );
        assert!(
            scores[1] < 0.5,
            "second score should be low for missing answer, got {}",
            scores[1]
        );
        assert!(
            scores[2] < 0.5,
            "third score should be low for missing answer, got {}",
            scores[2]
        );
        assert!(
            avg < 0.5,
            "average should be below 0.5 with 2 missing answers, got {avg}"
        );
    }

    // Both thresholds are inclusive lower bounds; values one epsilon below fall to the next tier.
    #[test]
    fn verdict_boundary_at_threshold() {
        let config = CompactionProbeConfig::default();

        let score = config.threshold;
        let verdict = if score >= config.threshold {
            ProbeVerdict::Pass
        } else if score >= config.hard_fail_threshold {
            ProbeVerdict::SoftFail
        } else {
            ProbeVerdict::HardFail
        };
        assert_eq!(verdict, ProbeVerdict::Pass);

        let score = config.threshold - f32::EPSILON;
        let verdict = if score >= config.threshold {
            ProbeVerdict::Pass
        } else if score >= config.hard_fail_threshold {
            ProbeVerdict::SoftFail
        } else {
            ProbeVerdict::HardFail
        };
        assert_eq!(verdict, ProbeVerdict::SoftFail);

        let score = config.hard_fail_threshold;
        let verdict = if score >= config.threshold {
            ProbeVerdict::Pass
        } else if score >= config.hard_fail_threshold {
            ProbeVerdict::SoftFail
        } else {
            ProbeVerdict::HardFail
        };
        assert_eq!(verdict, ProbeVerdict::SoftFail);

        let score = config.hard_fail_threshold - f32::EPSILON;
        let verdict = if score >= config.threshold {
            ProbeVerdict::Pass
        } else if score >= config.hard_fail_threshold {
            ProbeVerdict::SoftFail
        } else {
            ProbeVerdict::HardFail
        };
        assert_eq!(verdict, ProbeVerdict::HardFail);
    }

    // Struct-level `#[serde(default)]` fills in every omitted field.
    #[test]
    fn config_partial_json_uses_defaults() {
        let json = r#"{"enabled": true}"#;
        let c: CompactionProbeConfig =
            serde_json::from_str(json).expect("deserialize partial json");
        assert!(c.enabled);
        assert!(c.probe_provider.is_empty());
        assert!((c.threshold - 0.6).abs() < 0.001);
        assert!((c.hard_fail_threshold - 0.35).abs() < 0.001);
        assert_eq!(c.max_questions, 5);
        assert_eq!(c.timeout_secs, 15);
    }

    #[test]
    fn config_empty_json_uses_all_defaults() {
        let c: CompactionProbeConfig = serde_json::from_str("{}").expect("deserialize empty json");
        assert!(!c.enabled);
        assert!(c.probe_provider.is_empty());
    }

    // The lowercase names must match the categories requested in the generation prompt.
    #[test]
    fn probe_category_serde_lowercase() {
        assert_eq!(
            serde_json::to_string(&ProbeCategory::Recall).unwrap(),
            r#""recall""#
        );
        assert_eq!(
            serde_json::to_string(&ProbeCategory::Artifact).unwrap(),
            r#""artifact""#
        );
        assert_eq!(
            serde_json::to_string(&ProbeCategory::Continuation).unwrap(),
            r#""continuation""#
        );
        assert_eq!(
            serde_json::to_string(&ProbeCategory::Decision).unwrap(),
            r#""decision""#
        );
        let cat: ProbeCategory = serde_json::from_str(r#""recall""#).unwrap();
        assert_eq!(cat, ProbeCategory::Recall);
    }

    // Weights keyed by the lowercase category names parse from a TOML table.
    #[test]
    fn category_weights_toml_round_trip() {
        let toml_str = r#"
enabled = true
probe_provider = "fast"
threshold = 0.6
hard_fail_threshold = 0.35
max_questions = 5
timeout_secs = 15
[category_weights]
recall = 1.5
artifact = 1.0
continuation = 1.0
decision = 0.8
"#;
        let c: CompactionProbeConfig = toml::from_str(toml_str).expect("deserialize toml");
        let weights = c.category_weights.as_ref().unwrap();
        assert!((weights[&ProbeCategory::Recall] - 1.5).abs() < 0.001);
        assert!((weights[&ProbeCategory::Decision] - 0.8).abs() < 0.001);
    }

    // ---- compute_category_scores --------------------------------------------------------------

    #[test]
    fn category_scores_equal_weights() {
        let questions = vec![
            ProbeQuestion {
                question: "Q1".into(),
                expected_answer: "A1".into(),
                category: ProbeCategory::Recall,
            },
            ProbeQuestion {
                question: "Q2".into(),
                expected_answer: "A2".into(),
                category: ProbeCategory::Artifact,
            },
        ];
        let scores = [1.0_f32, 0.0_f32];
        let (cats, overall) = compute_category_scores(&questions, &scores, None);
        assert_eq!(cats.len(), 2);
        assert!((overall - 0.5).abs() < 0.001, "overall={overall}");
    }

    #[test]
    fn category_scores_missing_category_excluded() {
        let questions = vec![
            ProbeQuestion {
                question: "Q1".into(),
                expected_answer: "A1".into(),
                category: ProbeCategory::Recall,
            },
            ProbeQuestion {
                question: "Q2".into(),
                expected_answer: "A2".into(),
                category: ProbeCategory::Decision,
            },
        ];
        let scores = [1.0_f32, 0.6_f32];
        let (cats, _overall) = compute_category_scores(&questions, &scores, None);
        assert_eq!(cats.len(), 2, "only categories with questions present");
        let categories: Vec<_> = cats.iter().map(|c| c.category).collect();
        assert!(!categories.contains(&ProbeCategory::Artifact));
        assert!(!categories.contains(&ProbeCategory::Continuation));
    }

    // Weighted mean: (1.0 * 2.0 + 0.0 * 1.0) / (2.0 + 1.0) = 2/3.
    #[test]
    fn category_scores_custom_weights() {
        let questions = vec![
            ProbeQuestion {
                question: "Q1".into(),
                expected_answer: "A1".into(),
                category: ProbeCategory::Recall,
            },
            ProbeQuestion {
                question: "Q2".into(),
                expected_answer: "A2".into(),
                category: ProbeCategory::Decision,
            },
        ];
        let scores = [1.0_f32, 0.0_f32];
        let mut weights = HashMap::new();
        weights.insert(ProbeCategory::Recall, 2.0_f32);
        weights.insert(ProbeCategory::Decision, 1.0_f32);
        let (_, overall) = compute_category_scores(&questions, &scores, Some(&weights));
        assert!(
            (overall - 2.0 / 3.0).abs() < 0.001,
            "expected ~0.667, got {overall}"
        );
    }

    // With every weight zero the overall score falls back to the unweighted mean.
    #[test]
    fn category_scores_all_zero_weights_fallback() {
        let questions = vec![
            ProbeQuestion {
                question: "Q1".into(),
                expected_answer: "A1".into(),
                category: ProbeCategory::Recall,
            },
            ProbeQuestion {
                question: "Q2".into(),
                expected_answer: "A2".into(),
                category: ProbeCategory::Artifact,
            },
        ];
        let scores = [1.0_f32, 0.0_f32];
        let mut weights = HashMap::new();
        weights.insert(ProbeCategory::Recall, 0.0_f32);
        weights.insert(ProbeCategory::Artifact, 0.0_f32);
        let (_, overall) = compute_category_scores(&questions, &scores, Some(&weights));
        assert!((overall - 0.5).abs() < 0.001, "overall={overall}");
    }

    #[test]
    fn category_scores_empty_questions() {
        let (cats, overall) = compute_category_scores(&[], &[], None);
        assert!(cats.is_empty());
        assert!((overall - 0.0).abs() < 0.001);
    }

    #[test]
    fn category_scores_multi_probe_single_category_averages() {
        let questions = vec![
            ProbeQuestion {
                question: "Q1".into(),
                expected_answer: "A1".into(),
                category: ProbeCategory::Recall,
            },
            ProbeQuestion {
                question: "Q2".into(),
                expected_answer: "A2".into(),
                category: ProbeCategory::Recall,
            },
            ProbeQuestion {
                question: "Q3".into(),
                expected_answer: "A3".into(),
                category: ProbeCategory::Recall,
            },
        ];
        let scores = [1.0_f32, 0.0_f32, 0.5_f32];
        let (cats, overall) = compute_category_scores(&questions, &scores, None);
        assert_eq!(cats.len(), 1, "only one category present");
        assert_eq!(cats[0].category, ProbeCategory::Recall);
        assert_eq!(cats[0].probes_run, 3);
        assert!(
            (cats[0].score - 0.5).abs() < 0.001,
            "cat score={}",
            cats[0].score
        );
        assert!((overall - 0.5).abs() < 0.001, "overall={overall}");
    }

    // ---- ProbeQuestion serde --------------------------------------------------------------------

    // A question JSON without `category` (the LLM omitted it) defaults to Recall.
    #[test]
    fn probe_question_serde_default_category() {
        let json = r#"{"question":"What file?","expected_answer":"src/lib.rs"}"#;
        let q: ProbeQuestion = serde_json::from_str(json).expect("deserialize");
        assert_eq!(q.category, ProbeCategory::Recall);
        assert_eq!(q.question, "What file?");
        assert_eq!(q.expected_answer, "src/lib.rs");
    }

    #[test]
    fn probe_question_serde_all_categories_round_trip() {
        for cat in [
            ProbeCategory::Recall,
            ProbeCategory::Artifact,
            ProbeCategory::Continuation,
            ProbeCategory::Decision,
        ] {
            let q = ProbeQuestion {
                question: "test?".into(),
                expected_answer: "answer".into(),
                category: cat,
            };
            let json = serde_json::to_string(&q).expect("serialize");
            let restored: ProbeQuestion = serde_json::from_str(&json).expect("deserialize");
            assert_eq!(restored.category, cat);
        }
    }
}