1use std::collections::HashMap;
12use std::time::Instant;
13
14use schemars::JsonSchema;
15use serde::{Deserialize, Serialize};
16use zeph_llm::any::AnyProvider;
17use zeph_llm::provider::{LlmProvider as _, Message, MessageMetadata, MessagePart, Role};
18
19use crate::error::MemoryError;
20
/// Category of a probe question, used to bucket per-question scores and
/// apply optional per-category weights.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "lowercase")]
pub enum ProbeCategory {
    /// Specific facts that survived compaction (file paths, function names, values).
    Recall,
    /// Which files/tools/URLs the agent used.
    Artifact,
    /// Next steps, blockers, open questions.
    Continuation,
    /// Past reasoning traces (why X over Y, trade-offs).
    Decision,
}
36
/// Average score for one probe category within a single probe run.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CategoryScore {
    /// Category these probes belonged to.
    pub category: ProbeCategory,
    /// Mean per-question score for the category, in `[0.0, 1.0]`.
    pub score: f32,
    /// Number of questions that fell into this category.
    pub probes_run: u32,
}
46
/// A single probe question together with its ground-truth expected answer.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct ProbeQuestion {
    /// Question posed to the answering model.
    pub question: String,
    /// Expected answer, extractable from the original conversation.
    pub expected_answer: String,
    /// Category bucket; defaults to `recall` so older serialized data
    /// (without this field) still deserializes.
    #[serde(default = "default_probe_category")]
    pub category: ProbeCategory,
}
58
/// Serde default for [`ProbeQuestion::category`]: treat unlabeled questions as recall.
fn default_probe_category() -> ProbeCategory {
    ProbeCategory::Recall
}
62
63impl Default for ProbeQuestion {
64 fn default() -> Self {
65 Self {
66 question: String::new(),
67 expected_answer: String::new(),
68 category: ProbeCategory::Recall,
69 }
70 }
71}
72
/// Classification of a probe's overall score against the configured thresholds.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ProbeVerdict {
    /// Score >= `threshold`.
    Pass,
    /// Score below `threshold` but >= `hard_fail_threshold`.
    SoftFail,
    /// Score below `hard_fail_threshold`.
    HardFail,
    /// Probe could not complete. NOTE(review): never produced by `run_probe`
    /// in this file — presumably set by callers on failure; confirm usage.
    Error,
}
86
/// Outcome of one compaction probe run, suitable for persistence/telemetry.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CompactionProbeResult {
    /// Weighted overall score in `[0.0, 1.0]`.
    pub score: f32,
    /// Per-category averages; `default` keeps older serialized results
    /// (which lacked this field) deserializable.
    #[serde(default)]
    pub category_scores: Vec<CategoryScore>,
    /// Questions that were asked.
    pub questions: Vec<ProbeQuestion>,
    /// Answering model's answers, index-aligned with `questions`.
    pub answers: Vec<String>,
    /// Score for each question, index-aligned with `questions`.
    pub per_question_scores: Vec<f32>,
    /// Pass/soft-fail/hard-fail classification of `score`.
    pub verdict: ProbeVerdict,
    /// Pass threshold that was in effect for this run.
    pub threshold: f32,
    /// Hard-fail threshold that was in effect for this run.
    pub hard_fail_threshold: f32,
    /// Name of the provider that ran the probe.
    pub model: String,
    /// Wall-clock probe duration in milliseconds (stamped by `validate_compaction`).
    pub duration_ms: u64,
}
111
/// JSON envelope the question-generation prompt asks the LLM to produce.
#[derive(Debug, Deserialize, JsonSchema)]
struct ProbeQuestionsOutput {
    questions: Vec<ProbeQuestion>,
}
118
119#[must_use]
125fn compute_category_scores(
126 questions: &[ProbeQuestion],
127 per_question_scores: &[f32],
128 category_weights: Option<&HashMap<ProbeCategory, f32>>,
129) -> (Vec<CategoryScore>, f32) {
130 let mut by_cat: HashMap<ProbeCategory, Vec<f32>> = HashMap::new();
132 for (q, &s) in questions.iter().zip(per_question_scores.iter()) {
133 by_cat.entry(q.category).or_default().push(s);
134 }
135
136 #[allow(clippy::cast_precision_loss)]
137 let category_scores: Vec<CategoryScore> = by_cat
138 .into_iter()
139 .map(|(category, scores)| {
140 let avg = scores.iter().sum::<f32>() / scores.len() as f32;
141 CategoryScore {
142 category,
143 score: avg,
144 #[allow(clippy::cast_possible_truncation)]
145 probes_run: scores.len() as u32,
146 }
147 })
148 .collect();
149
150 if category_scores.is_empty() {
151 return (category_scores, 0.0);
152 }
153
154 let mut weighted_sum = 0.0_f32;
156 let mut weight_total = 0.0_f32;
157 for cs in &category_scores {
158 let raw_w = category_weights
159 .and_then(|m| m.get(&cs.category).copied())
160 .unwrap_or(1.0);
161 if raw_w < 0.0 {
162 tracing::warn!(
163 category = ?cs.category,
164 weight = raw_w,
165 "category_weights contains a negative value — treating as 0.0 (category excluded from scoring)"
166 );
167 }
168 let w = raw_w.max(0.0);
169 weighted_sum += cs.score * w;
170 weight_total += w;
171 }
172
173 let overall = if weight_total > 0.0 {
174 weighted_sum / weight_total
175 } else {
176 #[allow(clippy::cast_precision_loss)]
178 let n = category_scores.len() as f32;
179 category_scores.iter().map(|cs| cs.score).sum::<f32>() / n
180 };
181
182 (category_scores, overall)
183}
184
/// JSON envelope the answering prompt asks the LLM to produce.
#[derive(Debug, Deserialize, JsonSchema)]
struct ProbeAnswersOutput {
    answers: Vec<String>,
}
189
/// Lowercase substrings indicating the answering model could not find the
/// answer in the summary; any match scores the answer as 0.0.
const REFUSAL_PATTERNS: &[&str] = &[
    "unknown",
    "not mentioned",
    "not found",
    "n/a",
    "cannot determine",
    "no information",
    "not provided",
    "not specified",
    "not stated",
    "not available",
];

/// Returns `true` when `text` contains any known refusal phrase
/// (case-insensitive substring match).
fn is_refusal(text: &str) -> bool {
    let lowered = text.to_lowercase();
    for pattern in REFUSAL_PATTERNS {
        if lowered.contains(pattern) {
            return true;
        }
    }
    false
}
210
/// Lowercases `text` and splits on non-alphanumeric characters, keeping
/// only tokens of at least 3 bytes (drops stop-word-sized noise).
fn normalize_tokens(text: &str) -> Vec<String> {
    let lowered = text.to_lowercase();
    let mut tokens = Vec::new();
    for raw in lowered.split(|c: char| !c.is_alphanumeric()) {
        if raw.len() >= 3 {
            tokens.push(raw.to_owned());
        }
    }
    tokens
}
219
/// Jaccard similarity of two token lists (as sets). Two empty lists are
/// defined as identical (1.0).
fn jaccard(a: &[String], b: &[String]) -> f32 {
    use std::collections::HashSet;

    if a.is_empty() && b.is_empty() {
        return 1.0;
    }
    let left: HashSet<&str> = a.iter().map(String::as_str).collect();
    let right: HashSet<&str> = b.iter().map(String::as_str).collect();
    let shared = left.intersection(&right).count();
    let total = left.union(&right).count();
    if total == 0 {
        return 0.0;
    }
    #[allow(clippy::cast_precision_loss)]
    {
        shared as f32 / total as f32
    }
}
236
237fn score_pair(expected: &str, actual: &str) -> f32 {
239 if is_refusal(actual) {
240 return 0.0;
241 }
242
243 let tokens_e = normalize_tokens(expected);
244 let tokens_a = normalize_tokens(actual);
245
246 if !tokens_e.is_empty() {
248 let set_e: std::collections::HashSet<&str> = tokens_e.iter().map(String::as_str).collect();
249 let set_a: std::collections::HashSet<&str> = tokens_a.iter().map(String::as_str).collect();
250 if set_e.is_subset(&set_a) {
251 return 1.0;
252 }
253 }
254
255 let j_full = jaccard(&tokens_e, &tokens_a);
257
258 let set_e: std::collections::HashSet<&str> = tokens_e.iter().map(String::as_str).collect();
260 let set_a: std::collections::HashSet<&str> = tokens_a.iter().map(String::as_str).collect();
261 let intersection: Vec<String> = set_e
262 .intersection(&set_a)
263 .map(|s| (*s).to_owned())
264 .collect();
265
266 #[allow(clippy::cast_precision_loss)]
267 let j_e = if tokens_e.is_empty() {
268 0.0_f32
269 } else {
270 intersection.len() as f32 / tokens_e.len() as f32
271 };
272 #[allow(clippy::cast_precision_loss)]
273 let j_a = if tokens_a.is_empty() {
274 0.0_f32
275 } else {
276 intersection.len() as f32 / tokens_a.len() as f32
277 };
278
279 j_full.max(j_e).max(j_a)
280}
281
282#[must_use]
286pub fn score_answers(questions: &[ProbeQuestion], answers: &[String]) -> (Vec<f32>, f32) {
287 if questions.is_empty() {
288 return (vec![], 0.0);
289 }
290 let scores: Vec<f32> = questions
291 .iter()
292 .zip(answers.iter().chain(std::iter::repeat(&String::new())))
293 .map(|(q, a)| score_pair(&q.expected_answer, a))
294 .collect();
295 #[allow(clippy::cast_precision_loss)]
296 let avg = if scores.is_empty() {
297 0.0
298 } else {
299 scores.iter().sum::<f32>() / scores.len() as f32
300 };
301 (scores, avg)
302}
303
304fn truncate_tool_bodies(messages: &[Message]) -> Vec<Message> {
308 messages
309 .iter()
310 .map(|m| {
311 let mut msg = m.clone();
312 for part in &mut msg.parts {
313 if let MessagePart::ToolOutput { body, .. } = part {
314 if body.len() <= 500 {
315 continue;
316 }
317 body.truncate(500);
318 body.push('\u{2026}');
319 }
320 }
321 msg.rebuild_content();
322 msg
323 })
324 .collect()
325}
326
/// Generates up to `max_questions` factual probe questions from the
/// conversation via the probe LLM provider.
///
/// Tool output bodies are truncated first so raw dumps do not dominate the
/// prompt. The provider is asked for JSON matching [`ProbeQuestionsOutput`];
/// the result is truncated back to `max_questions` because the model may
/// over-generate.
///
/// # Errors
/// Returns [`MemoryError::Llm`] when the provider call fails.
pub async fn generate_probe_questions(
    provider: &AnyProvider,
    messages: &[Message],
    max_questions: usize,
) -> Result<Vec<ProbeQuestion>, MemoryError> {
    let truncated = truncate_tool_bodies(messages);

    // Flatten the conversation into "role: content" lines for the prompt.
    let mut history = String::new();
    for msg in &truncated {
        let role = match msg.role {
            Role::User => "user",
            Role::Assistant => "assistant",
            Role::System => "system",
        };
        history.push_str(role);
        history.push_str(": ");
        history.push_str(&msg.content);
        history.push('\n');
    }

    let prompt = format!(
        "Given the following conversation excerpt, generate {max_questions} factual questions \
         that test whether a summary preserves the most important concrete details.\n\
         \n\
         You MUST generate at least one question per category when max_questions >= 4. \
         If the conversation lacks information for a category, generate a question noting that absence.\n\
         \n\
         Categories:\n\
         - recall: Specific facts that survived (file paths, function names, values). \
         Example: \"What file was modified?\"\n\
         - artifact: Which files/tools/URLs the agent used. \
         Example: \"Which tool was executed?\"\n\
         - continuation: Next steps, blockers, open questions. \
         Example: \"What is the next step?\"\n\
         - decision: Past reasoning traces (why X over Y, trade-offs). \
         Example: \"Why was X chosen over Y?\"\n\
         \n\
         Do NOT generate questions about:\n\
         - Raw tool output content (compiler warnings, test output line numbers)\n\
         - Intermediate debugging steps that were superseded\n\
         - Opinions or reasoning that cannot be verified\n\
         \n\
         Each question must have a single unambiguous expected answer extractable from the text.\n\
         \n\
         Conversation:\n{history}\n\
         \n\
         Respond in JSON with schema: {{\"questions\": [{{\"question\": \"...\", \
         \"expected_answer\": \"...\", \"category\": \"recall|artifact|continuation|decision\"}}]}}"
    );

    let msgs = [Message {
        role: Role::User,
        content: prompt,
        parts: vec![],
        metadata: MessageMetadata::default(),
    }];

    let mut output: ProbeQuestionsOutput = provider
        .chat_typed_erased::<ProbeQuestionsOutput>(&msgs)
        .await
        .map_err(MemoryError::Llm)?;

    // The model may return more questions than requested.
    output.questions.truncate(max_questions);

    Ok(output.questions)
}
402
/// Asks the probe LLM to answer each question using ONLY the compacted
/// summary, returning the raw answers.
///
/// The prompt instructs the model to reply "UNKNOWN" when the summary lacks
/// the answer; such replies are scored 0.0 by `is_refusal`. Note the answer
/// count is not validated here — `score_answers` pads missing entries.
///
/// # Errors
/// Returns [`MemoryError::Llm`] when the provider call fails.
pub async fn answer_probe_questions(
    provider: &AnyProvider,
    summary: &str,
    questions: &[ProbeQuestion],
) -> Result<Vec<String>, MemoryError> {
    // Number the questions 1..=N to match the answers array the model returns.
    let mut numbered = String::new();
    for (i, q) in questions.iter().enumerate() {
        use std::fmt::Write as _;
        let _ = writeln!(numbered, "{}. {}", i + 1, q.question);
    }

    let prompt = format!(
        "Given the following summary of a conversation, answer each question using ONLY \
         information present in the summary. If the answer is not in the summary, respond \
         with \"UNKNOWN\".\n\
         \n\
         Summary:\n{summary}\n\
         \n\
         Questions:\n{numbered}\n\
         \n\
         Respond in JSON with schema: {{\"answers\": [\"answer1\", \"answer2\", ...]}}"
    );

    let msgs = [Message {
        role: Role::User,
        content: prompt,
        parts: vec![],
        metadata: MessageMetadata::default(),
    }];

    let output: ProbeAnswersOutput = provider
        .chat_typed_erased::<ProbeAnswersOutput>(&msgs)
        .await
        .map_err(MemoryError::Llm)?;

    Ok(output.answers)
}
445
/// Configuration for the compaction probe. The container-level
/// `#[serde(default)]` lets partial configs fill missing fields from
/// [`Default`].
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(default)]
pub struct CompactionProbeConfig {
    /// Master switch; when false `validate_compaction` returns `Ok(None)`.
    pub enabled: bool,
    /// Name of the LLM provider to probe with (resolved by the caller —
    /// not used inside this module).
    pub probe_provider: String,
    /// Minimum overall score for a `Pass` verdict.
    pub threshold: f32,
    /// Scores below this are `HardFail`; between this and `threshold`, `SoftFail`.
    pub hard_fail_threshold: f32,
    /// Maximum number of probe questions to generate.
    pub max_questions: usize,
    /// Overall probe timeout; expiry skips the probe rather than failing it.
    pub timeout_secs: u64,
    /// Optional per-category weights for the overall score; categories not
    /// listed default to weight 1.0.
    #[serde(default)]
    pub category_weights: Option<HashMap<ProbeCategory, f32>>,
}
470
impl Default for CompactionProbeConfig {
    /// Disabled by default; thresholds 0.6/0.35, 5 questions, 15s timeout.
    fn default() -> Self {
        Self {
            enabled: false,
            probe_provider: String::new(),
            threshold: 0.6,
            hard_fail_threshold: 0.35,
            max_questions: 5,
            timeout_secs: 15,
            category_weights: None,
        }
    }
}
484
/// Runs the compaction probe against `summary` under an overall timeout.
///
/// Returns `Ok(None)` when probing is disabled, the probe times out, or
/// `run_probe` decides to skip (short summary / too few questions);
/// otherwise the probe result with `duration_ms` stamped in.
///
/// A timeout is logged and treated as a skip — compaction proceeds without
/// a verdict rather than failing.
///
/// # Errors
/// Propagates provider errors from question generation and answering.
pub async fn validate_compaction(
    provider: &AnyProvider,
    messages: &[Message],
    summary: &str,
    config: &CompactionProbeConfig,
) -> Result<Option<CompactionProbeResult>, MemoryError> {
    if !config.enabled {
        return Ok(None);
    }

    let timeout = std::time::Duration::from_secs(config.timeout_secs);
    let start = Instant::now();

    let result = tokio::time::timeout(timeout, async {
        run_probe(provider, messages, summary, config).await
    })
    .await;

    match result {
        Ok(inner) => inner,
        Err(_elapsed) => {
            // Timeout is non-fatal by design: probe absence must not block
            // compaction.
            tracing::warn!(
                timeout_secs = config.timeout_secs,
                "compaction probe timed out — proceeding with compaction"
            );
            Ok(None)
        }
    }
    // Stamp the measured wall-clock duration onto whatever result came back.
    .map(|opt| {
        opt.map(|mut r| {
            r.duration_ms = u64::try_from(start.elapsed().as_millis()).unwrap_or(u64::MAX);
            r
        })
    })
}
533
/// Core probe pipeline: generate questions, answer them from the summary,
/// score the answers, and classify the overall score.
///
/// Returns `Ok(None)` (a skip, not a failure) when the summary is shorter
/// than 10 bytes or fewer than 2 questions were generated. `duration_ms` is
/// left at 0 here; `validate_compaction` fills it in afterwards.
async fn run_probe(
    provider: &AnyProvider,
    messages: &[Message],
    summary: &str,
    config: &CompactionProbeConfig,
) -> Result<Option<CompactionProbeResult>, MemoryError> {
    // A degenerate summary cannot meaningfully answer probes; skip.
    if summary.len() < 10 {
        tracing::warn!(
            len = summary.len(),
            "compaction probe: summary too short — skipping probe"
        );
        return Ok(None);
    }

    let questions = generate_probe_questions(provider, messages, config.max_questions).await?;

    // With fewer than 2 questions the score is too noisy to act on.
    if questions.len() < 2 {
        tracing::debug!(
            count = questions.len(),
            "compaction probe: fewer than 2 questions generated — skipping probe"
        );
        return Ok(None);
    }

    // The prompt requires one question per category when max_questions >= 4;
    // log (but tolerate) any category the model skipped.
    if config.max_questions >= 4 {
        use std::collections::HashSet;
        let covered: HashSet<_> = questions.iter().map(|q| q.category).collect();
        for cat in [
            ProbeCategory::Recall,
            ProbeCategory::Artifact,
            ProbeCategory::Continuation,
            ProbeCategory::Decision,
        ] {
            if !covered.contains(&cat) {
                tracing::warn!(
                    category = ?cat,
                    "compaction probe: LLM did not generate questions for category"
                );
            }
        }
    }

    let answers = answer_probe_questions(provider, summary, &questions).await?;

    let (per_question_scores, _simple_avg) = score_answers(&questions, &answers);

    // The category-weighted score (not the simple average) drives the verdict.
    let (category_scores, score) = compute_category_scores(
        &questions,
        &per_question_scores,
        config.category_weights.as_ref(),
    );

    let verdict = if score >= config.threshold {
        ProbeVerdict::Pass
    } else if score >= config.hard_fail_threshold {
        ProbeVerdict::SoftFail
    } else {
        ProbeVerdict::HardFail
    };

    let model = provider.name().to_owned();

    Ok(Some(CompactionProbeResult {
        score,
        category_scores,
        questions,
        answers,
        per_question_scores,
        verdict,
        threshold: config.threshold,
        hard_fail_threshold: config.hard_fail_threshold,
        model,
        duration_ms: 0, // stamped by validate_compaction
    }))
}
610
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand constructor to cut `ProbeQuestion` struct-literal boilerplate.
    fn pq(question: &str, expected: &str, category: ProbeCategory) -> ProbeQuestion {
        ProbeQuestion {
            question: question.into(),
            expected_answer: expected.into(),
            category,
        }
    }

    /// Mirrors the verdict if-chain in `run_probe`. The original tests
    /// duplicated this mapping inline seven times.
    fn verdict_for(score: f32, config: &CompactionProbeConfig) -> ProbeVerdict {
        if score >= config.threshold {
            ProbeVerdict::Pass
        } else if score >= config.hard_fail_threshold {
            ProbeVerdict::SoftFail
        } else {
            ProbeVerdict::HardFail
        }
    }

    #[test]
    fn score_perfect_match() {
        let q = vec![pq("What crate is used?", "thiserror", ProbeCategory::Recall)];
        let a = vec!["thiserror".into()];
        let (scores, avg) = score_answers(&q, &a);
        assert_eq!(scores.len(), 1);
        assert!((avg - 1.0).abs() < 0.01, "expected ~1.0, got {avg}");
    }

    #[test]
    fn score_complete_mismatch() {
        let q = vec![pq("What file was modified?", "src/auth.rs", ProbeCategory::Recall)];
        let a = vec!["definitely not in the summary".into()];
        let (scores, avg) = score_answers(&q, &a);
        assert_eq!(scores.len(), 1);
        assert!(avg < 0.5, "expected low score, got {avg}");
    }

    #[test]
    fn score_refusal_is_zero() {
        let q = vec![pq(
            "What was the decision?",
            "Use thiserror for typed errors",
            ProbeCategory::Recall,
        )];
        for refusal in &[
            "UNKNOWN",
            "not mentioned",
            "N/A",
            "cannot determine",
            "No information",
        ] {
            let a = vec![(*refusal).to_owned()];
            let (_, avg) = score_answers(&q, &a);
            assert!(avg < 0.01, "expected 0 for refusal '{refusal}', got {avg}");
        }
    }

    #[test]
    fn score_paraphrased_answer_above_half() {
        let q = vec![pq(
            "What error handling crate was chosen?",
            "Use thiserror for typed errors in library crates",
            ProbeCategory::Recall,
        )];
        let a = vec!["thiserror was chosen for error types in library crates".into()];
        let (_, avg) = score_answers(&q, &a);
        assert!(avg > 0.5, "expected >0.5 for paraphrase, got {avg}");
    }

    #[test]
    fn score_empty_strings() {
        // Empty expected vs empty actual is defined as a perfect match.
        let q = vec![pq("What?", "", ProbeCategory::Recall)];
        let a = vec![String::new()];
        let (scores, avg) = score_answers(&q, &a);
        assert_eq!(scores.len(), 1);
        assert!(
            (avg - 1.0).abs() < 0.01,
            "expected 1.0 for empty vs empty, got {avg}"
        );
    }

    #[test]
    fn score_empty_questions_list() {
        let (scores, avg) = score_answers(&[], &[]);
        assert!(scores.is_empty());
        assert!((avg - 0.0).abs() < 0.01);
    }

    #[test]
    fn score_file_path_exact() {
        let q = vec![pq(
            "Which file was modified?",
            "crates/zeph-memory/src/compaction_probe.rs",
            ProbeCategory::Recall,
        )];
        let a = vec!["The file crates/zeph-memory/src/compaction_probe.rs was modified.".into()];
        let (_, avg) = score_answers(&q, &a);
        assert!(
            avg > 0.8,
            "expected high score for file path match, got {avg}"
        );
    }

    #[test]
    fn score_unicode_input() {
        // Multi-byte (Cyrillic) input must score without panicking.
        let q = vec![pq(
            "Что было изменено?",
            "файл config.toml",
            ProbeCategory::Recall,
        )];
        let a = vec!["config.toml был изменён".into()];
        let (scores, _) = score_answers(&q, &a);
        assert_eq!(scores.len(), 1);
    }

    #[test]
    fn verdict_thresholds() {
        let config = CompactionProbeConfig::default();
        assert_eq!(verdict_for(0.7, &config), ProbeVerdict::Pass);
        assert_eq!(verdict_for(0.5, &config), ProbeVerdict::SoftFail);
        assert_eq!(verdict_for(0.2, &config), ProbeVerdict::HardFail);
    }

    #[test]
    fn config_defaults() {
        let c = CompactionProbeConfig::default();
        assert!(!c.enabled);
        assert!(c.probe_provider.is_empty());
        assert!((c.threshold - 0.6).abs() < 0.001);
        assert!((c.hard_fail_threshold - 0.35).abs() < 0.001);
        assert_eq!(c.max_questions, 5);
        assert_eq!(c.timeout_secs, 15);
        assert!(c.category_weights.is_none());
    }

    #[test]
    fn config_serde_round_trip() {
        let original = CompactionProbeConfig {
            enabled: true,
            probe_provider: "fast".into(),
            threshold: 0.65,
            hard_fail_threshold: 0.4,
            max_questions: 5,
            timeout_secs: 20,
            category_weights: None,
        };
        let json = serde_json::to_string(&original).expect("serialize");
        let restored: CompactionProbeConfig = serde_json::from_str(&json).expect("deserialize");
        assert!(restored.enabled);
        assert_eq!(restored.probe_provider, "fast");
        assert!((restored.threshold - 0.65).abs() < 0.001);
    }

    #[test]
    fn probe_result_serde_round_trip() {
        let result = CompactionProbeResult {
            score: 0.75,
            category_scores: vec![CategoryScore {
                category: ProbeCategory::Recall,
                score: 0.75,
                probes_run: 1,
            }],
            questions: vec![pq("What?", "thiserror", ProbeCategory::Recall)],
            answers: vec!["thiserror".into()],
            per_question_scores: vec![1.0],
            verdict: ProbeVerdict::Pass,
            threshold: 0.6,
            hard_fail_threshold: 0.35,
            model: "haiku".into(),
            duration_ms: 1234,
        };
        let json = serde_json::to_string(&result).expect("serialize");
        let restored: CompactionProbeResult = serde_json::from_str(&json).expect("deserialize");
        assert!((restored.score - 0.75).abs() < 0.001);
        assert_eq!(restored.verdict, ProbeVerdict::Pass);
        assert_eq!(restored.category_scores.len(), 1);
    }

    #[test]
    fn probe_result_backward_compat_no_category_scores() {
        // Payload from before `category_scores` existed must still deserialize.
        let json = r#"{"score":0.75,"questions":[],"answers":[],"per_question_scores":[],"verdict":"Pass","threshold":0.6,"hard_fail_threshold":0.35,"model":"haiku","duration_ms":0}"#;
        let restored: CompactionProbeResult = serde_json::from_str(json).expect("deserialize");
        assert!(restored.category_scores.is_empty());
    }

    #[test]
    fn score_fewer_answers_than_questions() {
        // Missing answers are padded with empty strings and score low.
        let questions = vec![
            pq("What crate?", "thiserror", ProbeCategory::Recall),
            pq("What file?", "src/lib.rs", ProbeCategory::Recall),
            pq("What decision?", "use async traits", ProbeCategory::Recall),
        ];
        let answers = vec!["thiserror".into()];
        let (scores, avg) = score_answers(&questions, &answers);
        assert_eq!(scores.len(), 3);
        assert!(
            (scores[0] - 1.0).abs() < 0.01,
            "first score should be ~1.0, got {}",
            scores[0]
        );
        assert!(
            scores[1] < 0.5,
            "second score should be low for missing answer, got {}",
            scores[1]
        );
        assert!(
            scores[2] < 0.5,
            "third score should be low for missing answer, got {}",
            scores[2]
        );
        assert!(
            avg < 0.5,
            "average should be below 0.5 with 2 missing answers, got {avg}"
        );
    }

    #[test]
    fn verdict_boundary_at_threshold() {
        // Both thresholds are inclusive on the higher verdict.
        let config = CompactionProbeConfig::default();
        assert_eq!(verdict_for(config.threshold, &config), ProbeVerdict::Pass);
        assert_eq!(
            verdict_for(config.threshold - f32::EPSILON, &config),
            ProbeVerdict::SoftFail
        );
        assert_eq!(
            verdict_for(config.hard_fail_threshold, &config),
            ProbeVerdict::SoftFail
        );
        assert_eq!(
            verdict_for(config.hard_fail_threshold - f32::EPSILON, &config),
            ProbeVerdict::HardFail
        );
    }

    #[test]
    fn config_partial_json_uses_defaults() {
        let json = r#"{"enabled": true}"#;
        let c: CompactionProbeConfig =
            serde_json::from_str(json).expect("deserialize partial json");
        assert!(c.enabled);
        assert!(c.probe_provider.is_empty());
        assert!((c.threshold - 0.6).abs() < 0.001);
        assert!((c.hard_fail_threshold - 0.35).abs() < 0.001);
        assert_eq!(c.max_questions, 5);
        assert_eq!(c.timeout_secs, 15);
    }

    #[test]
    fn config_empty_json_uses_all_defaults() {
        let c: CompactionProbeConfig = serde_json::from_str("{}").expect("deserialize empty json");
        assert!(!c.enabled);
        assert!(c.probe_provider.is_empty());
    }

    #[test]
    fn probe_category_serde_lowercase() {
        assert_eq!(
            serde_json::to_string(&ProbeCategory::Recall).unwrap(),
            r#""recall""#
        );
        assert_eq!(
            serde_json::to_string(&ProbeCategory::Artifact).unwrap(),
            r#""artifact""#
        );
        assert_eq!(
            serde_json::to_string(&ProbeCategory::Continuation).unwrap(),
            r#""continuation""#
        );
        assert_eq!(
            serde_json::to_string(&ProbeCategory::Decision).unwrap(),
            r#""decision""#
        );
        let cat: ProbeCategory = serde_json::from_str(r#""recall""#).unwrap();
        assert_eq!(cat, ProbeCategory::Recall);
    }

    #[test]
    fn category_weights_toml_round_trip() {
        let toml_str = r#"
enabled = true
probe_provider = "fast"
threshold = 0.6
hard_fail_threshold = 0.35
max_questions = 5
timeout_secs = 15
[category_weights]
recall = 1.5
artifact = 1.0
continuation = 1.0
decision = 0.8
"#;
        let c: CompactionProbeConfig = toml::from_str(toml_str).expect("deserialize toml");
        let weights = c.category_weights.as_ref().unwrap();
        assert!((weights[&ProbeCategory::Recall] - 1.5).abs() < 0.001);
        assert!((weights[&ProbeCategory::Decision] - 0.8).abs() < 0.001);
    }

    #[test]
    fn category_scores_equal_weights() {
        let questions = vec![
            pq("Q1", "A1", ProbeCategory::Recall),
            pq("Q2", "A2", ProbeCategory::Artifact),
        ];
        let scores = [1.0_f32, 0.0_f32];
        let (cats, overall) = compute_category_scores(&questions, &scores, None);
        assert_eq!(cats.len(), 2);
        assert!((overall - 0.5).abs() < 0.001, "overall={overall}");
    }

    #[test]
    fn category_scores_missing_category_excluded() {
        let questions = vec![
            pq("Q1", "A1", ProbeCategory::Recall),
            pq("Q2", "A2", ProbeCategory::Decision),
        ];
        let scores = [1.0_f32, 0.6_f32];
        let (cats, _overall) = compute_category_scores(&questions, &scores, None);
        assert_eq!(cats.len(), 2, "only categories with questions present");
        let categories: Vec<_> = cats.iter().map(|c| c.category).collect();
        assert!(!categories.contains(&ProbeCategory::Artifact));
        assert!(!categories.contains(&ProbeCategory::Continuation));
    }

    #[test]
    fn category_scores_custom_weights() {
        let questions = vec![
            pq("Q1", "A1", ProbeCategory::Recall),
            pq("Q2", "A2", ProbeCategory::Decision),
        ];
        let scores = [1.0_f32, 0.0_f32];
        let mut weights = HashMap::new();
        weights.insert(ProbeCategory::Recall, 2.0_f32);
        weights.insert(ProbeCategory::Decision, 1.0_f32);
        let (_, overall) = compute_category_scores(&questions, &scores, Some(&weights));
        // (1.0 * 2 + 0.0 * 1) / 3 = 2/3.
        assert!(
            (overall - 2.0 / 3.0).abs() < 0.001,
            "expected ~0.667, got {overall}"
        );
    }

    #[test]
    fn category_scores_all_zero_weights_fallback() {
        let questions = vec![
            pq("Q1", "A1", ProbeCategory::Recall),
            pq("Q2", "A2", ProbeCategory::Artifact),
        ];
        let scores = [1.0_f32, 0.0_f32];
        let mut weights = HashMap::new();
        weights.insert(ProbeCategory::Recall, 0.0_f32);
        weights.insert(ProbeCategory::Artifact, 0.0_f32);
        // All-zero weights fall back to an unweighted mean.
        let (_, overall) = compute_category_scores(&questions, &scores, Some(&weights));
        assert!((overall - 0.5).abs() < 0.001, "overall={overall}");
    }

    #[test]
    fn category_scores_empty_questions() {
        let (cats, overall) = compute_category_scores(&[], &[], None);
        assert!(cats.is_empty());
        assert!((overall - 0.0).abs() < 0.001);
    }

    #[test]
    fn category_scores_multi_probe_single_category_averages() {
        let questions = vec![
            pq("Q1", "A1", ProbeCategory::Recall),
            pq("Q2", "A2", ProbeCategory::Recall),
            pq("Q3", "A3", ProbeCategory::Recall),
        ];
        let scores = [1.0_f32, 0.0_f32, 0.5_f32];
        let (cats, overall) = compute_category_scores(&questions, &scores, None);
        assert_eq!(cats.len(), 1, "only one category present");
        assert_eq!(cats[0].category, ProbeCategory::Recall);
        assert_eq!(cats[0].probes_run, 3);
        assert!(
            (cats[0].score - 0.5).abs() < 0.001,
            "cat score={}",
            cats[0].score
        );
        assert!((overall - 0.5).abs() < 0.001, "overall={overall}");
    }

    #[test]
    fn probe_question_serde_default_category() {
        // Older payloads lacking `category` must default to recall.
        let json = r#"{"question":"What file?","expected_answer":"src/lib.rs"}"#;
        let q: ProbeQuestion = serde_json::from_str(json).expect("deserialize");
        assert_eq!(q.category, ProbeCategory::Recall);
        assert_eq!(q.question, "What file?");
        assert_eq!(q.expected_answer, "src/lib.rs");
    }

    #[test]
    fn probe_question_serde_all_categories_round_trip() {
        for cat in [
            ProbeCategory::Recall,
            ProbeCategory::Artifact,
            ProbeCategory::Continuation,
            ProbeCategory::Decision,
        ] {
            let q = pq("test?", "answer", cat);
            let json = serde_json::to_string(&q).expect("serialize");
            let restored: ProbeQuestion = serde_json::from_str(&json).expect("deserialize");
            assert_eq!(restored.category, cat);
        }
    }
}