1use std::collections::HashMap;
12use std::time::Instant;
13
14use schemars::JsonSchema;
15use serde::{Deserialize, Serialize};
16use zeph_llm::any::AnyProvider;
17use zeph_llm::provider::{LlmProvider as _, Message, MessageMetadata, MessagePart, Role};
18
19use crate::error::MemoryError;
20
21use zeph_config::memory::{CompactionProbeConfig, ProbeCategory};
24
/// Aggregated probe score for a single [`ProbeCategory`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CategoryScore {
    /// Category these probes belong to.
    pub category: ProbeCategory,
    /// Mean per-question score within this category, in `[0.0, 1.0]`.
    pub score: f32,
    /// Number of probe questions that contributed to `score`.
    pub probes_run: u32,
}
34
/// A factual question generated from the conversation, paired with the
/// answer a faithful summary is expected to preserve.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct ProbeQuestion {
    /// The question posed to the probe model.
    pub question: String,
    /// Ground-truth answer extractable from the original conversation.
    pub expected_answer: String,
    /// Kind of information the question tests; falls back to `recall`
    /// when the generating LLM omits the field.
    #[serde(default = "default_probe_category")]
    pub category: ProbeCategory,
}
46
/// Serde default for [`ProbeQuestion::category`].
fn default_probe_category() -> ProbeCategory {
    ProbeCategory::Recall
}
50
51impl Default for ProbeQuestion {
52 fn default() -> Self {
53 Self {
54 question: String::new(),
55 expected_answer: String::new(),
56 category: ProbeCategory::Recall,
57 }
58 }
59}
60
/// Outcome of a compaction probe run, derived from the weighted score.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ProbeVerdict {
    /// Score reached the configured pass threshold.
    Pass,
    /// Score below the pass threshold but at or above the hard-fail threshold.
    SoftFail,
    /// Score below the hard-fail threshold.
    HardFail,
    /// The probe itself failed to run.
    Error,
}
74
/// Full record of a single compaction-probe evaluation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CompactionProbeResult {
    /// Overall (category-weighted) score in `[0.0, 1.0]`.
    pub score: f32,
    /// Per-category breakdown; `#[serde(default)]` keeps records serialized
    /// before this field existed deserializable (empty vec).
    #[serde(default)]
    pub category_scores: Vec<CategoryScore>,
    /// Questions that were asked.
    pub questions: Vec<ProbeQuestion>,
    /// Answers the probe model produced from the summary alone.
    pub answers: Vec<String>,
    /// Score per question, parallel to `questions`.
    pub per_question_scores: Vec<f32>,
    /// Verdict derived from `score` and the two thresholds below.
    pub verdict: ProbeVerdict,
    /// Pass threshold in effect for this run.
    pub threshold: f32,
    /// Hard-fail threshold in effect for this run.
    pub hard_fail_threshold: f32,
    /// Name of the provider/model that ran the probe.
    pub model: String,
    /// Wall-clock duration of the probe in milliseconds.
    pub duration_ms: u64,
}
99
/// JSON envelope the LLM is instructed to emit when generating questions.
#[derive(Debug, Deserialize, JsonSchema)]
struct ProbeQuestionsOutput {
    questions: Vec<ProbeQuestion>,
}
106
107#[must_use]
113fn compute_category_scores(
114 questions: &[ProbeQuestion],
115 per_question_scores: &[f32],
116 category_weights: Option<&HashMap<ProbeCategory, f32>>,
117) -> (Vec<CategoryScore>, f32) {
118 let mut by_cat: HashMap<ProbeCategory, Vec<f32>> = HashMap::new();
120 for (q, &s) in questions.iter().zip(per_question_scores.iter()) {
121 by_cat.entry(q.category).or_default().push(s);
122 }
123
124 #[allow(clippy::cast_precision_loss)]
125 let category_scores: Vec<CategoryScore> = by_cat
126 .into_iter()
127 .map(|(category, scores)| {
128 let avg = scores.iter().sum::<f32>() / scores.len() as f32;
129 CategoryScore {
130 category,
131 score: avg,
132 #[allow(clippy::cast_possible_truncation)]
133 probes_run: scores.len() as u32,
134 }
135 })
136 .collect();
137
138 if category_scores.is_empty() {
139 return (category_scores, 0.0);
140 }
141
142 let mut weighted_sum = 0.0_f32;
144 let mut weight_total = 0.0_f32;
145 for cs in &category_scores {
146 let raw_w = category_weights
147 .and_then(|m| m.get(&cs.category).copied())
148 .unwrap_or(1.0);
149 if raw_w < 0.0 {
150 tracing::warn!(
151 category = ?cs.category,
152 weight = raw_w,
153 "category_weights contains a negative value — treating as 0.0 (category excluded from scoring)"
154 );
155 }
156 let w = raw_w.max(0.0);
157 weighted_sum += cs.score * w;
158 weight_total += w;
159 }
160
161 let overall = if weight_total > 0.0 {
162 weighted_sum / weight_total
163 } else {
164 #[allow(clippy::cast_precision_loss)]
166 let n = category_scores.len() as f32;
167 category_scores.iter().map(|cs| cs.score).sum::<f32>() / n
168 };
169
170 (category_scores, overall)
171}
172
/// JSON envelope the LLM is instructed to emit when answering questions.
#[derive(Debug, Deserialize, JsonSchema)]
struct ProbeAnswersOutput {
    answers: Vec<String>,
}
177
/// Lowercase substrings that mark an answer as a refusal / "don't know".
const REFUSAL_PATTERNS: &[&str] = &[
    "unknown",
    "not mentioned",
    "not found",
    "n/a",
    "cannot determine",
    "no information",
    "not provided",
    "not specified",
    "not stated",
    "not available",
];

/// True when `text` contains any refusal marker, case-insensitively.
fn is_refusal(text: &str) -> bool {
    let lowered = text.to_lowercase();
    REFUSAL_PATTERNS
        .iter()
        .any(|pattern| lowered.contains(pattern))
}
198
/// Lowercase `text`, split on non-alphanumeric boundaries, and keep only
/// tokens of at least three bytes (drops short noise words like "a", "to").
fn normalize_tokens(text: &str) -> Vec<String> {
    let lowered = text.to_lowercase();
    lowered
        .split(|c: char| !c.is_alphanumeric())
        .filter(|token| token.len() >= 3)
        .map(str::to_owned)
        .collect()
}
207
/// Jaccard similarity over the distinct tokens of `a` and `b`.
///
/// Two empty token lists are defined as identical (1.0).
fn jaccard(a: &[String], b: &[String]) -> f32 {
    use std::collections::HashSet;

    if a.is_empty() && b.is_empty() {
        return 1.0;
    }
    let left: HashSet<&str> = a.iter().map(String::as_str).collect();
    let right: HashSet<&str> = b.iter().map(String::as_str).collect();
    let shared = left.intersection(&right).count();
    let combined = left.union(&right).count();
    if combined == 0 {
        // Guard against division by zero; in practice the all-empty case
        // already returned above, so a non-empty union is expected here.
        return 0.0;
    }
    #[allow(clippy::cast_precision_loss)]
    {
        shared as f32 / combined as f32
    }
}
224
225fn score_pair(expected: &str, actual: &str) -> f32 {
227 if is_refusal(actual) {
228 return 0.0;
229 }
230
231 let tokens_e = normalize_tokens(expected);
232 let tokens_a = normalize_tokens(actual);
233
234 if !tokens_e.is_empty() {
236 let set_e: std::collections::HashSet<&str> = tokens_e.iter().map(String::as_str).collect();
237 let set_a: std::collections::HashSet<&str> = tokens_a.iter().map(String::as_str).collect();
238 if set_e.is_subset(&set_a) {
239 return 1.0;
240 }
241 }
242
243 let j_full = jaccard(&tokens_e, &tokens_a);
245
246 let set_e: std::collections::HashSet<&str> = tokens_e.iter().map(String::as_str).collect();
248 let set_a: std::collections::HashSet<&str> = tokens_a.iter().map(String::as_str).collect();
249 let intersection: Vec<String> = set_e
250 .intersection(&set_a)
251 .map(|s| (*s).to_owned())
252 .collect();
253
254 #[allow(clippy::cast_precision_loss)]
255 let j_e = if tokens_e.is_empty() {
256 0.0_f32
257 } else {
258 intersection.len() as f32 / tokens_e.len() as f32
259 };
260 #[allow(clippy::cast_precision_loss)]
261 let j_a = if tokens_a.is_empty() {
262 0.0_f32
263 } else {
264 intersection.len() as f32 / tokens_a.len() as f32
265 };
266
267 j_full.max(j_e).max(j_a)
268}
269
270#[must_use]
274pub fn score_answers(questions: &[ProbeQuestion], answers: &[String]) -> (Vec<f32>, f32) {
275 if questions.is_empty() {
276 return (vec![], 0.0);
277 }
278 let scores: Vec<f32> = questions
279 .iter()
280 .zip(answers.iter().chain(std::iter::repeat(&String::new())))
281 .map(|(q, a)| score_pair(&q.expected_answer, a))
282 .collect();
283 #[allow(clippy::cast_precision_loss)]
284 let avg = if scores.is_empty() {
285 0.0
286 } else {
287 scores.iter().sum::<f32>() / scores.len() as f32
288 };
289 (scores, avg)
290}
291
292fn truncate_tool_bodies(messages: &[Message]) -> Vec<Message> {
296 messages
297 .iter()
298 .map(|m| {
299 let mut msg = m.clone();
300 for part in &mut msg.parts {
301 if let MessagePart::ToolOutput { body, .. } = part {
302 if body.len() <= 500 {
303 continue;
304 }
305 body.truncate(500);
306 body.push('\u{2026}');
307 }
308 }
309 msg.rebuild_content();
310 msg
311 })
312 .collect()
313}
314
/// Ask `provider` to generate up to `max_questions` factual probe questions
/// from `messages`, with tool-output bodies truncated to keep the prompt
/// small.
///
/// # Errors
/// Returns [`MemoryError::Llm`] when the provider call fails.
#[cfg_attr(
    feature = "profiling",
    tracing::instrument(name = "memory.compaction_probe", skip_all)
)]
pub async fn generate_probe_questions(
    provider: &AnyProvider,
    messages: &[Message],
    max_questions: usize,
) -> Result<Vec<ProbeQuestion>, MemoryError> {
    let truncated = truncate_tool_bodies(messages);

    // Flatten the conversation into "role: content" lines for the prompt.
    let mut history = String::new();
    for msg in &truncated {
        let role = match msg.role {
            Role::User => "user",
            Role::Assistant => "assistant",
            Role::System => "system",
        };
        history.push_str(role);
        history.push_str(": ");
        history.push_str(&msg.content);
        history.push('\n');
    }

    let prompt = format!(
        "Given the following conversation excerpt, generate {max_questions} factual questions \
         that test whether a summary preserves the most important concrete details.\n\
         \n\
         You MUST generate at least one question per category when max_questions >= 4. \
         If the conversation lacks information for a category, generate a question noting that absence.\n\
         \n\
         Categories:\n\
         - recall: Specific facts that survived (file paths, function names, values). \
         Example: \"What file was modified?\"\n\
         - artifact: Which files/tools/URLs the agent used. \
         Example: \"Which tool was executed?\"\n\
         - continuation: Next steps, blockers, open questions. \
         Example: \"What is the next step?\"\n\
         - decision: Past reasoning traces (why X over Y, trade-offs). \
         Example: \"Why was X chosen over Y?\"\n\
         \n\
         Do NOT generate questions about:\n\
         - Raw tool output content (compiler warnings, test output line numbers)\n\
         - Intermediate debugging steps that were superseded\n\
         - Opinions or reasoning that cannot be verified\n\
         \n\
         Each question must have a single unambiguous expected answer extractable from the text.\n\
         \n\
         Conversation:\n{history}\n\
         \n\
         Respond in JSON with schema: {{\"questions\": [{{\"question\": \"...\", \
         \"expected_answer\": \"...\", \"category\": \"recall|artifact|continuation|decision\"}}]}}"
    );

    let msgs = [Message {
        role: Role::User,
        content: prompt,
        parts: vec![],
        metadata: MessageMetadata::default(),
    }];

    let mut output: ProbeQuestionsOutput = provider
        .chat_typed_erased::<ProbeQuestionsOutput>(&msgs)
        .await
        .map_err(MemoryError::Llm)?;

    // The model may over-generate; enforce the caller's cap.
    output.questions.truncate(max_questions);

    Ok(output.questions)
}
394
/// Ask `provider` to answer each probe question using ONLY the summary text,
/// responding "UNKNOWN" when the summary lacks the answer.
///
/// # Errors
/// Returns [`MemoryError::Llm`] when the provider call fails.
pub async fn answer_probe_questions(
    provider: &AnyProvider,
    summary: &str,
    questions: &[ProbeQuestion],
) -> Result<Vec<String>, MemoryError> {
    // Render the questions as a 1-based numbered list.
    let mut numbered = String::new();
    for (i, q) in questions.iter().enumerate() {
        use std::fmt::Write as _;
        let _ = writeln!(numbered, "{}. {}", i + 1, q.question);
    }

    let prompt = format!(
        "Given the following summary of a conversation, answer each question using ONLY \
         information present in the summary. If the answer is not in the summary, respond \
         with \"UNKNOWN\".\n\
         \n\
         Summary:\n{summary}\n\
         \n\
         Questions:\n{numbered}\n\
         \n\
         Respond in JSON with schema: {{\"answers\": [\"answer1\", \"answer2\", ...]}}"
    );

    let msgs = [Message {
        role: Role::User,
        content: prompt,
        parts: vec![],
        metadata: MessageMetadata::default(),
    }];

    let output: ProbeAnswersOutput = provider
        .chat_typed_erased::<ProbeAnswersOutput>(&msgs)
        .await
        .map_err(MemoryError::Llm)?;

    Ok(output.answers)
}
437
438pub async fn validate_compaction(
452 provider: AnyProvider,
453 messages: Vec<Message>,
454 summary: String,
455 config: &CompactionProbeConfig,
456) -> Result<Option<CompactionProbeResult>, MemoryError> {
457 if !config.enabled {
458 return Ok(None);
459 }
460
461 let timeout = std::time::Duration::from_secs(config.timeout_secs);
462 let start = Instant::now();
463
464 let result = tokio::time::timeout(timeout, async {
465 run_probe(provider, messages, summary, config).await
466 })
467 .await;
468
469 match result {
470 Ok(inner) => inner,
471 Err(_elapsed) => {
472 tracing::warn!(
473 timeout_secs = config.timeout_secs,
474 "compaction probe timed out — proceeding with compaction"
475 );
476 Ok(None)
477 }
478 }
479 .map(|opt| {
480 opt.map(|mut r| {
481 r.duration_ms = u64::try_from(start.elapsed().as_millis()).unwrap_or(u64::MAX);
482 r
483 })
484 })
485}
486
/// Execute the probe end-to-end: generate questions from the conversation,
/// answer them from the summary alone, score the answers, and derive a
/// verdict.
///
/// Returns `Ok(None)` when the probe is skipped (summary too short, or fewer
/// than two questions generated). `duration_ms` is left at 0 here and is
/// filled in by `validate_compaction`.
///
/// # Errors
/// Returns [`MemoryError::Llm`] when either provider call fails.
async fn run_probe(
    provider: AnyProvider,
    messages: Vec<Message>,
    summary: String,
    config: &CompactionProbeConfig,
) -> Result<Option<CompactionProbeResult>, MemoryError> {
    // A degenerate summary cannot be probed meaningfully.
    if summary.len() < 10 {
        tracing::warn!(
            len = summary.len(),
            "compaction probe: summary too short — skipping probe"
        );
        return Ok(None);
    }

    let questions = generate_probe_questions(&provider, &messages, config.max_questions).await?;

    // Too few questions would make the score statistically meaningless.
    if questions.len() < 2 {
        tracing::debug!(
            count = questions.len(),
            "compaction probe: fewer than 2 questions generated — skipping probe"
        );
        return Ok(None);
    }

    // Best-effort coverage check: warn (but do not fail) when the LLM left a
    // category without questions despite a budget of at least 4.
    if config.max_questions >= 4 {
        use std::collections::HashSet;
        let covered: HashSet<_> = questions.iter().map(|q| q.category).collect();
        for cat in [
            ProbeCategory::Recall,
            ProbeCategory::Artifact,
            ProbeCategory::Continuation,
            ProbeCategory::Decision,
        ] {
            if !covered.contains(&cat) {
                tracing::warn!(
                    category = ?cat,
                    "compaction probe: LLM did not generate questions for category"
                );
            }
        }
    }

    let answers = answer_probe_questions(&provider, &summary, &questions).await?;

    let (per_question_scores, _simple_avg) = score_answers(&questions, &answers);

    // The overall score is the category-weighted one, not the simple average.
    let (category_scores, score) = compute_category_scores(
        &questions,
        &per_question_scores,
        config.category_weights.as_ref(),
    );

    // Both comparisons are inclusive: exactly hitting a threshold passes it.
    let verdict = if score >= config.threshold {
        ProbeVerdict::Pass
    } else if score >= config.hard_fail_threshold {
        ProbeVerdict::SoftFail
    } else {
        ProbeVerdict::HardFail
    };

    let model = provider.name().to_owned();

    Ok(Some(CompactionProbeResult {
        score,
        category_scores,
        questions,
        answers,
        per_question_scores,
        verdict,
        threshold: config.threshold,
        hard_fail_threshold: config.hard_fail_threshold,
        model,
        // Overwritten with the real elapsed time by `validate_compaction`.
        duration_ms: 0, }))
}
563
#[cfg(test)]
mod tests {
    use super::*;

    // An exact-match answer should score ~1.0.
    #[test]
    fn score_perfect_match() {
        let q = vec![ProbeQuestion {
            question: "What crate is used?".into(),
            expected_answer: "thiserror".into(),
            category: ProbeCategory::Recall,
        }];
        let a = vec!["thiserror".into()];
        let (scores, avg) = score_answers(&q, &a);
        assert_eq!(scores.len(), 1);
        assert!((avg - 1.0).abs() < 0.01, "expected ~1.0, got {avg}");
    }

    // An unrelated answer shares no tokens and should score low.
    #[test]
    fn score_complete_mismatch() {
        let q = vec![ProbeQuestion {
            question: "What file was modified?".into(),
            expected_answer: "src/auth.rs".into(),
            ..Default::default()
        }];
        let a = vec!["definitely not in the summary".into()];
        let (scores, avg) = score_answers(&q, &a);
        assert_eq!(scores.len(), 1);
        assert!(avg < 0.5, "expected low score, got {avg}");
    }

    // Any refusal phrasing, regardless of case, must score zero.
    #[test]
    fn score_refusal_is_zero() {
        let q = vec![ProbeQuestion {
            question: "What was the decision?".into(),
            expected_answer: "Use thiserror for typed errors".into(),
            ..Default::default()
        }];
        for refusal in &[
            "UNKNOWN",
            "not mentioned",
            "N/A",
            "cannot determine",
            "No information",
        ] {
            let a = vec![(*refusal).to_owned()];
            let (_, avg) = score_answers(&q, &a);
            assert!(avg < 0.01, "expected 0 for refusal '{refusal}', got {avg}");
        }
    }

    // Token-overlap scoring should credit paraphrases above 0.5.
    #[test]
    fn score_paraphrased_answer_above_half() {
        let q = vec![ProbeQuestion {
            question: "What error handling crate was chosen?".into(),
            expected_answer: "Use thiserror for typed errors in library crates".into(),
            ..Default::default()
        }];
        let a = vec!["thiserror was chosen for error types in library crates".into()];
        let (_, avg) = score_answers(&q, &a);
        assert!(avg > 0.5, "expected >0.5 for paraphrase, got {avg}");
    }

    // Empty expected vs empty actual is defined as a perfect match.
    #[test]
    fn score_empty_strings() {
        let q = vec![ProbeQuestion {
            question: "What?".into(),
            expected_answer: String::new(),
            ..Default::default()
        }];
        let a = vec![String::new()];
        let (scores, avg) = score_answers(&q, &a);
        assert_eq!(scores.len(), 1);
        assert!(
            (avg - 1.0).abs() < 0.01,
            "expected 1.0 for empty vs empty, got {avg}"
        );
    }

    // No questions at all => empty scores and a 0.0 average.
    #[test]
    fn score_empty_questions_list() {
        let (scores, avg) = score_answers(&[], &[]);
        assert!(scores.is_empty());
        assert!((avg - 0.0).abs() < 0.01);
    }

    // A file path embedded in a longer sentence should still score high.
    #[test]
    fn score_file_path_exact() {
        let q = vec![ProbeQuestion {
            question: "Which file was modified?".into(),
            expected_answer: "crates/zeph-memory/src/compaction_probe.rs".into(),
            ..Default::default()
        }];
        let a = vec!["The file crates/zeph-memory/src/compaction_probe.rs was modified.".into()];
        let (_, avg) = score_answers(&q, &a);
        assert!(
            avg > 0.8,
            "expected high score for file path match, got {avg}"
        );
    }

    // Non-ASCII (multi-byte) input must not panic the tokenizer/scorer.
    #[test]
    fn score_unicode_input() {
        let q = vec![ProbeQuestion {
            question: "Что было изменено?".into(),
            expected_answer: "файл config.toml".into(),
            ..Default::default()
        }];
        let a = vec!["config.toml был изменён".into()];
        let (scores, _) = score_answers(&q, &a);
        assert_eq!(scores.len(), 1);
    }

    // Scores above/between/below the thresholds map to the three verdicts.
    #[test]
    fn verdict_thresholds() {
        let config = CompactionProbeConfig::default();

        let score = 0.7_f32;
        let verdict = if score >= config.threshold {
            ProbeVerdict::Pass
        } else if score >= config.hard_fail_threshold {
            ProbeVerdict::SoftFail
        } else {
            ProbeVerdict::HardFail
        };
        assert_eq!(verdict, ProbeVerdict::Pass);

        let score = 0.5_f32;
        let verdict = if score >= config.threshold {
            ProbeVerdict::Pass
        } else if score >= config.hard_fail_threshold {
            ProbeVerdict::SoftFail
        } else {
            ProbeVerdict::HardFail
        };
        assert_eq!(verdict, ProbeVerdict::SoftFail);

        let score = 0.2_f32;
        let verdict = if score >= config.threshold {
            ProbeVerdict::Pass
        } else if score >= config.hard_fail_threshold {
            ProbeVerdict::SoftFail
        } else {
            ProbeVerdict::HardFail
        };
        assert_eq!(verdict, ProbeVerdict::HardFail);
    }

    // Sanity-check the config's documented default values.
    #[test]
    fn config_defaults() {
        let c = CompactionProbeConfig::default();
        assert!(!c.enabled);
        assert!(c.probe_provider.is_none());
        assert!((c.threshold - 0.6).abs() < 0.001);
        assert!((c.hard_fail_threshold - 0.35).abs() < 0.001);
        assert_eq!(c.max_questions, 5);
        assert_eq!(c.timeout_secs, 15);
        assert!(c.category_weights.is_none());
    }

    // Config survives a JSON serialize/deserialize round trip.
    #[test]
    fn config_serde_round_trip() {
        let original = CompactionProbeConfig {
            enabled: true,
            probe_provider: Some(zeph_config::ProviderName::new("fast")),
            threshold: 0.65,
            hard_fail_threshold: 0.4,
            max_questions: 5,
            timeout_secs: 20,
            category_weights: None,
        };
        let json = serde_json::to_string(&original).expect("serialize");
        let restored: CompactionProbeConfig = serde_json::from_str(&json).expect("deserialize");
        assert!(restored.enabled);
        assert_eq!(restored.probe_provider.as_deref(), Some("fast"));
        assert!((restored.threshold - 0.65).abs() < 0.001);
    }

    // Probe results survive a JSON round trip.
    #[test]
    fn probe_result_serde_round_trip() {
        let result = CompactionProbeResult {
            score: 0.75,
            category_scores: vec![CategoryScore {
                category: ProbeCategory::Recall,
                score: 0.75,
                probes_run: 1,
            }],
            questions: vec![ProbeQuestion {
                question: "What?".into(),
                expected_answer: "thiserror".into(),
                category: ProbeCategory::Recall,
            }],
            answers: vec!["thiserror".into()],
            per_question_scores: vec![1.0],
            verdict: ProbeVerdict::Pass,
            threshold: 0.6,
            hard_fail_threshold: 0.35,
            model: "haiku".into(),
            duration_ms: 1234,
        };
        let json = serde_json::to_string(&result).expect("serialize");
        let restored: CompactionProbeResult = serde_json::from_str(&json).expect("deserialize");
        assert!((restored.score - 0.75).abs() < 0.001);
        assert_eq!(restored.verdict, ProbeVerdict::Pass);
        assert_eq!(restored.category_scores.len(), 1);
    }

    // Records written before `category_scores` existed must still parse
    // (the field is `#[serde(default)]`).
    #[test]
    fn probe_result_backward_compat_no_category_scores() {
        let json = r#"{"score":0.75,"questions":[],"answers":[],"per_question_scores":[],"verdict":"Pass","threshold":0.6,"hard_fail_threshold":0.35,"model":"haiku","duration_ms":0}"#;
        let restored: CompactionProbeResult = serde_json::from_str(json).expect("deserialize");
        assert!(restored.category_scores.is_empty());
    }

    // Missing answers are padded with empty strings and scored low.
    #[test]
    fn score_fewer_answers_than_questions() {
        let questions = vec![
            ProbeQuestion {
                question: "What crate?".into(),
                expected_answer: "thiserror".into(),
                ..Default::default()
            },
            ProbeQuestion {
                question: "What file?".into(),
                expected_answer: "src/lib.rs".into(),
                ..Default::default()
            },
            ProbeQuestion {
                question: "What decision?".into(),
                expected_answer: "use async traits".into(),
                ..Default::default()
            },
        ];
        let answers = vec!["thiserror".into()];
        let (scores, avg) = score_answers(&questions, &answers);
        assert_eq!(scores.len(), 3);
        assert!(
            (scores[0] - 1.0).abs() < 0.01,
            "first score should be ~1.0, got {}",
            scores[0]
        );
        assert!(
            scores[1] < 0.5,
            "second score should be low for missing answer, got {}",
            scores[1]
        );
        assert!(
            scores[2] < 0.5,
            "third score should be low for missing answer, got {}",
            scores[2]
        );
        assert!(
            avg < 0.5,
            "average should be below 0.5 with 2 missing answers, got {avg}"
        );
    }

    // Threshold comparisons are inclusive (>=) at both boundaries.
    #[test]
    fn verdict_boundary_at_threshold() {
        let config = CompactionProbeConfig::default();

        let score = config.threshold;
        let verdict = if score >= config.threshold {
            ProbeVerdict::Pass
        } else if score >= config.hard_fail_threshold {
            ProbeVerdict::SoftFail
        } else {
            ProbeVerdict::HardFail
        };
        assert_eq!(verdict, ProbeVerdict::Pass);

        let score = config.threshold - f32::EPSILON;
        let verdict = if score >= config.threshold {
            ProbeVerdict::Pass
        } else if score >= config.hard_fail_threshold {
            ProbeVerdict::SoftFail
        } else {
            ProbeVerdict::HardFail
        };
        assert_eq!(verdict, ProbeVerdict::SoftFail);

        let score = config.hard_fail_threshold;
        let verdict = if score >= config.threshold {
            ProbeVerdict::Pass
        } else if score >= config.hard_fail_threshold {
            ProbeVerdict::SoftFail
        } else {
            ProbeVerdict::HardFail
        };
        assert_eq!(verdict, ProbeVerdict::SoftFail);

        let score = config.hard_fail_threshold - f32::EPSILON;
        let verdict = if score >= config.threshold {
            ProbeVerdict::Pass
        } else if score >= config.hard_fail_threshold {
            ProbeVerdict::SoftFail
        } else {
            ProbeVerdict::HardFail
        };
        assert_eq!(verdict, ProbeVerdict::HardFail);
    }

    // Fields absent from the JSON fall back to their serde defaults.
    #[test]
    fn config_partial_json_uses_defaults() {
        let json = r#"{"enabled": true}"#;
        let c: CompactionProbeConfig =
            serde_json::from_str(json).expect("deserialize partial json");
        assert!(c.enabled);
        assert!(c.probe_provider.is_none());
        assert!((c.threshold - 0.6).abs() < 0.001);
        assert!((c.hard_fail_threshold - 0.35).abs() < 0.001);
        assert_eq!(c.max_questions, 5);
        assert_eq!(c.timeout_secs, 15);
    }

    // An empty JSON object deserializes to the full default config.
    #[test]
    fn config_empty_json_uses_all_defaults() {
        let c: CompactionProbeConfig = serde_json::from_str("{}").expect("deserialize empty json");
        assert!(!c.enabled);
        assert!(c.probe_provider.is_none());
    }

    // Categories serialize as lowercase strings — the same spelling the
    // question-generation prompt asks the LLM to emit.
    #[test]
    fn probe_category_serde_lowercase() {
        assert_eq!(
            serde_json::to_string(&ProbeCategory::Recall).unwrap(),
            r#""recall""#
        );
        assert_eq!(
            serde_json::to_string(&ProbeCategory::Artifact).unwrap(),
            r#""artifact""#
        );
        assert_eq!(
            serde_json::to_string(&ProbeCategory::Continuation).unwrap(),
            r#""continuation""#
        );
        assert_eq!(
            serde_json::to_string(&ProbeCategory::Decision).unwrap(),
            r#""decision""#
        );
        let cat: ProbeCategory = serde_json::from_str(r#""recall""#).unwrap();
        assert_eq!(cat, ProbeCategory::Recall);
    }

    // Category weights parse from a TOML table.
    #[test]
    fn category_weights_toml_round_trip() {
        let toml_str = r#"
enabled = true
probe_provider = "fast"
threshold = 0.6
hard_fail_threshold = 0.35
max_questions = 5
timeout_secs = 15
[category_weights]
recall = 1.5
artifact = 1.0
continuation = 1.0
decision = 0.8
"#;
        let c: CompactionProbeConfig = toml::from_str(toml_str).expect("deserialize toml");
        let weights = c.category_weights.as_ref().unwrap();
        assert!((weights[&ProbeCategory::Recall] - 1.5).abs() < 0.001);
        assert!((weights[&ProbeCategory::Decision] - 0.8).abs() < 0.001);
    }

    // With no weights configured every category counts equally.
    #[test]
    fn category_scores_equal_weights() {
        let questions = vec![
            ProbeQuestion {
                question: "Q1".into(),
                expected_answer: "A1".into(),
                category: ProbeCategory::Recall,
            },
            ProbeQuestion {
                question: "Q2".into(),
                expected_answer: "A2".into(),
                category: ProbeCategory::Artifact,
            },
        ];
        let scores = [1.0_f32, 0.0_f32];
        let (cats, overall) = compute_category_scores(&questions, &scores, None);
        assert_eq!(cats.len(), 2);
        assert!((overall - 0.5).abs() < 0.001, "overall={overall}");
    }

    // Categories with no questions simply do not appear in the breakdown.
    #[test]
    fn category_scores_missing_category_excluded() {
        let questions = vec![
            ProbeQuestion {
                question: "Q1".into(),
                expected_answer: "A1".into(),
                category: ProbeCategory::Recall,
            },
            ProbeQuestion {
                question: "Q2".into(),
                expected_answer: "A2".into(),
                category: ProbeCategory::Decision,
            },
        ];
        let scores = [1.0_f32, 0.6_f32];
        let (cats, _overall) = compute_category_scores(&questions, &scores, None);
        assert_eq!(cats.len(), 2, "only categories with questions present");
        let categories: Vec<_> = cats.iter().map(|c| c.category).collect();
        assert!(!categories.contains(&ProbeCategory::Artifact));
        assert!(!categories.contains(&ProbeCategory::Continuation));
    }

    // Custom weights shift the overall towards the heavier category:
    // (1.0*2 + 0.0*1) / (2 + 1) = 2/3.
    #[test]
    fn category_scores_custom_weights() {
        let questions = vec![
            ProbeQuestion {
                question: "Q1".into(),
                expected_answer: "A1".into(),
                category: ProbeCategory::Recall,
            },
            ProbeQuestion {
                question: "Q2".into(),
                expected_answer: "A2".into(),
                category: ProbeCategory::Decision,
            },
        ];
        let scores = [1.0_f32, 0.0_f32];
        let mut weights = HashMap::new();
        weights.insert(ProbeCategory::Recall, 2.0_f32);
        weights.insert(ProbeCategory::Decision, 1.0_f32);
        let (_, overall) = compute_category_scores(&questions, &scores, Some(&weights));
        assert!(
            (overall - 2.0 / 3.0).abs() < 0.001,
            "expected ~0.667, got {overall}"
        );
    }

    // All-zero weights fall back to the unweighted mean of category scores.
    #[test]
    fn category_scores_all_zero_weights_fallback() {
        let questions = vec![
            ProbeQuestion {
                question: "Q1".into(),
                expected_answer: "A1".into(),
                category: ProbeCategory::Recall,
            },
            ProbeQuestion {
                question: "Q2".into(),
                expected_answer: "A2".into(),
                category: ProbeCategory::Artifact,
            },
        ];
        let scores = [1.0_f32, 0.0_f32];
        let mut weights = HashMap::new();
        weights.insert(ProbeCategory::Recall, 0.0_f32);
        weights.insert(ProbeCategory::Artifact, 0.0_f32);
        let (_, overall) = compute_category_scores(&questions, &scores, Some(&weights));
        assert!((overall - 0.5).abs() < 0.001, "overall={overall}");
    }

    // No questions => empty breakdown and a 0.0 overall.
    #[test]
    fn category_scores_empty_questions() {
        let (cats, overall) = compute_category_scores(&[], &[], None);
        assert!(cats.is_empty());
        assert!((overall - 0.0).abs() < 0.001);
    }

    // Multiple probes in one category average into a single entry.
    #[test]
    fn category_scores_multi_probe_single_category_averages() {
        let questions = vec![
            ProbeQuestion {
                question: "Q1".into(),
                expected_answer: "A1".into(),
                category: ProbeCategory::Recall,
            },
            ProbeQuestion {
                question: "Q2".into(),
                expected_answer: "A2".into(),
                category: ProbeCategory::Recall,
            },
            ProbeQuestion {
                question: "Q3".into(),
                expected_answer: "A3".into(),
                category: ProbeCategory::Recall,
            },
        ];
        let scores = [1.0_f32, 0.0_f32, 0.5_f32];
        let (cats, overall) = compute_category_scores(&questions, &scores, None);
        assert_eq!(cats.len(), 1, "only one category present");
        assert_eq!(cats[0].category, ProbeCategory::Recall);
        assert_eq!(cats[0].probes_run, 3);
        assert!(
            (cats[0].score - 0.5).abs() < 0.001,
            "cat score={}",
            cats[0].score
        );
        assert!((overall - 0.5).abs() < 0.001, "overall={overall}");
    }

    // A JSON question without a category defaults to recall.
    #[test]
    fn probe_question_serde_default_category() {
        let json = r#"{"question":"What file?","expected_answer":"src/lib.rs"}"#;
        let q: ProbeQuestion = serde_json::from_str(json).expect("deserialize");
        assert_eq!(q.category, ProbeCategory::Recall);
        assert_eq!(q.question, "What file?");
        assert_eq!(q.expected_answer, "src/lib.rs");
    }

    // Every category value survives a serde round trip on ProbeQuestion.
    #[test]
    fn probe_question_serde_all_categories_round_trip() {
        for cat in [
            ProbeCategory::Recall,
            ProbeCategory::Artifact,
            ProbeCategory::Continuation,
            ProbeCategory::Decision,
        ] {
            let q = ProbeQuestion {
                question: "test?".into(),
                expected_answer: "answer".into(),
                category: cat,
            };
            let json = serde_json::to_string(&q).expect("serialize");
            let restored: ProbeQuestion = serde_json::from_str(&json).expect("deserialize");
            assert_eq!(restored.category, cat);
        }
    }
}