1use std::collections::HashMap;
12use std::time::Instant;
13
14use schemars::JsonSchema;
15use serde::{Deserialize, Serialize};
16use zeph_llm::any::AnyProvider;
17use zeph_llm::provider::{LlmProvider as _, Message, MessageMetadata, MessagePart, Role};
18
19use crate::error::MemoryError;
20
21use zeph_config::memory::{CompactionProbeConfig, ProbeCategory};
24
/// Aggregated probe score for a single [`ProbeCategory`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CategoryScore {
    /// Category the score was computed for.
    pub category: ProbeCategory,
    /// Arithmetic mean of the per-question scores recorded for this category.
    pub score: f32,
    /// Number of questions that contributed to `score`.
    pub probes_run: u32,
}
34
// One factual probe question together with the answer a faithful summary
// should still be able to produce.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct ProbeQuestion {
    // Question text shown to the answering model.
    pub question: String,
    // Reference answer the model's reply is scored against.
    pub expected_answer: String,
    // Category of the question; when the field is missing during
    // deserialization it falls back to `default_probe_category()`.
    #[serde(default = "default_probe_category")]
    pub category: ProbeCategory,
}
46
/// Serde default for [`ProbeQuestion::category`]: questions without an
/// explicit category are treated as `Recall`.
fn default_probe_category() -> ProbeCategory {
    ProbeCategory::Recall
}
50
51impl Default for ProbeQuestion {
52 fn default() -> Self {
53 Self {
54 question: String::new(),
55 expected_answer: String::new(),
56 category: ProbeCategory::Recall,
57 }
58 }
59}
60
/// Outcome of a compaction probe, derived from the overall score and the two
/// configured thresholds.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[non_exhaustive]
pub enum ProbeVerdict {
    /// Score is at or above the pass threshold.
    Pass,
    /// Score is below the pass threshold but at or above the hard-fail
    /// threshold.
    SoftFail,
    /// Score is below the hard-fail threshold.
    HardFail,
    /// NOTE(review): never produced by the code in this file; presumably set
    /// by callers when the probe itself could not run — confirm.
    Error,
}
75
/// Full record of one compaction probe run.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CompactionProbeResult {
    /// Overall (category-weighted) score the verdict was derived from.
    pub score: f32,
    /// Per-category breakdown; defaults to empty when absent in serialized
    /// data.
    #[serde(default)]
    pub category_scores: Vec<CategoryScore>,
    /// Questions that were generated for the probe.
    pub questions: Vec<ProbeQuestion>,
    /// Answers returned by the model, in the order it produced them.
    pub answers: Vec<String>,
    /// Score of each question, index-aligned with `questions`.
    pub per_question_scores: Vec<f32>,
    /// Verdict obtained by comparing `score` with the two thresholds below.
    pub verdict: ProbeVerdict,
    /// Pass threshold that was in effect for this run.
    pub threshold: f32,
    /// Hard-fail threshold that was in effect for this run.
    pub hard_fail_threshold: f32,
    /// Name reported by the provider that ran the probe.
    pub model: String,
    /// Wall-clock duration of the probe in milliseconds.
    pub duration_ms: u64,
}
100
// Shape of the structured LLM response requested by
// `generate_probe_questions`.
#[derive(Debug, Deserialize, JsonSchema)]
struct ProbeQuestionsOutput {
    // Generated questions, in the order the model returned them.
    questions: Vec<ProbeQuestion>,
}
107
108#[must_use]
114fn compute_category_scores(
115 questions: &[ProbeQuestion],
116 per_question_scores: &[f32],
117 category_weights: Option<&HashMap<ProbeCategory, f32>>,
118) -> (Vec<CategoryScore>, f32) {
119 let mut by_cat: HashMap<ProbeCategory, Vec<f32>> = HashMap::new();
121 for (q, &s) in questions.iter().zip(per_question_scores.iter()) {
122 by_cat.entry(q.category).or_default().push(s);
123 }
124
125 #[allow(clippy::cast_precision_loss)]
126 let category_scores: Vec<CategoryScore> = by_cat
127 .into_iter()
128 .map(|(category, scores)| {
129 let avg = scores.iter().sum::<f32>() / scores.len() as f32;
130 CategoryScore {
131 category,
132 score: avg,
133 #[allow(clippy::cast_possible_truncation)]
134 probes_run: scores.len() as u32,
135 }
136 })
137 .collect();
138
139 if category_scores.is_empty() {
140 return (category_scores, 0.0);
141 }
142
143 let mut weighted_sum = 0.0_f32;
145 let mut weight_total = 0.0_f32;
146 for cs in &category_scores {
147 let raw_w = category_weights
148 .and_then(|m| m.get(&cs.category).copied())
149 .unwrap_or(1.0);
150 if raw_w < 0.0 {
151 tracing::warn!(
152 category = ?cs.category,
153 weight = raw_w,
154 "category_weights contains a negative value — treating as 0.0 (category excluded from scoring)"
155 );
156 }
157 let w = raw_w.max(0.0);
158 weighted_sum += cs.score * w;
159 weight_total += w;
160 }
161
162 let overall = if weight_total > 0.0 {
163 weighted_sum / weight_total
164 } else {
165 #[allow(clippy::cast_precision_loss)]
167 let n = category_scores.len() as f32;
168 category_scores.iter().map(|cs| cs.score).sum::<f32>() / n
169 };
170
171 (category_scores, overall)
172}
173
// Shape of the structured LLM response requested by
// `answer_probe_questions`.
#[derive(Debug, Deserialize, JsonSchema)]
struct ProbeAnswersOutput {
    // Answers as returned by the model; the prompt asks for one per question.
    answers: Vec<String>,
}
178
/// Lower-case phrases that mark an answer as "the summary does not say".
const REFUSAL_PATTERNS: &[&str] = &[
    "unknown",
    "not mentioned",
    "not found",
    "n/a",
    "cannot determine",
    "no information",
    "not provided",
    "not specified",
    "not stated",
    "not available",
];

/// Returns `true` when `text` contains one of [`REFUSAL_PATTERNS`]
/// (case-insensitively) as a standalone phrase.
///
/// A match only counts when it is not directly preceded or followed by an
/// alphanumeric character. Plain substring matching would misclassify
/// legitimate answers such as the path `src/main/app.rs`, which happens to
/// contain `n/a`.
fn is_refusal(text: &str) -> bool {
    let lower = text.to_lowercase();
    REFUSAL_PATTERNS.iter().any(|pattern| {
        lower.match_indices(*pattern).any(|(start, hit)| {
            let end = start + hit.len();
            let glued_before = lower[..start]
                .chars()
                .next_back()
                .is_some_and(char::is_alphanumeric);
            let glued_after = lower[end..]
                .chars()
                .next()
                .is_some_and(char::is_alphanumeric);
            !glued_before && !glued_after
        })
    })
}
199
/// Lower-cases `text`, splits it on every non-alphanumeric character and
/// keeps the pieces that are at least three bytes long.
///
/// The length check uses the UTF-8 byte length, so short non-ASCII tokens
/// (e.g. two Cyrillic letters) are kept while two-letter ASCII words are
/// dropped. Duplicates and original order are preserved.
fn normalize_tokens(text: &str) -> Vec<String> {
    const MIN_TOKEN_BYTES: usize = 3;

    let lowered = text.to_lowercase();
    let mut tokens = Vec::new();
    for piece in lowered.split(|ch: char| !ch.is_alphanumeric()) {
        if piece.len() < MIN_TOKEN_BYTES {
            continue;
        }
        tokens.push(piece.to_owned());
    }
    tokens
}
208
/// Jaccard similarity (|A ∩ B| / |A ∪ B|) of the two token lists, treated as
/// sets (duplicate tokens count once).
///
/// Two empty inputs are defined as identical and yield `1.0`.
#[allow(clippy::cast_precision_loss)] // token counts are small
fn jaccard(a: &[String], b: &[String]) -> f32 {
    let left: std::collections::HashSet<&str> = a.iter().map(String::as_str).collect();
    let right: std::collections::HashSet<&str> = b.iter().map(String::as_str).collect();

    let shared = left.intersection(&right).count();
    // |A ∪ B| = |A| + |B| − |A ∩ B|
    let combined = left.len() + right.len() - shared;

    // The union is empty only when both inputs are empty.
    if combined == 0 {
        return 1.0;
    }
    shared as f32 / combined as f32
}
225
226fn score_pair(expected: &str, actual: &str) -> f32 {
228 if is_refusal(actual) {
229 return 0.0;
230 }
231
232 let tokens_e = normalize_tokens(expected);
233 let tokens_a = normalize_tokens(actual);
234
235 if !tokens_e.is_empty() {
237 let set_e: std::collections::HashSet<&str> = tokens_e.iter().map(String::as_str).collect();
238 let set_a: std::collections::HashSet<&str> = tokens_a.iter().map(String::as_str).collect();
239 if set_e.is_subset(&set_a) {
240 return 1.0;
241 }
242 }
243
244 let j_full = jaccard(&tokens_e, &tokens_a);
246
247 let set_e: std::collections::HashSet<&str> = tokens_e.iter().map(String::as_str).collect();
249 let set_a: std::collections::HashSet<&str> = tokens_a.iter().map(String::as_str).collect();
250 let intersection: Vec<String> = set_e
251 .intersection(&set_a)
252 .map(|s| (*s).to_owned())
253 .collect();
254
255 #[allow(clippy::cast_precision_loss)]
256 let j_e = if tokens_e.is_empty() {
257 0.0_f32
258 } else {
259 intersection.len() as f32 / tokens_e.len() as f32
260 };
261 #[allow(clippy::cast_precision_loss)]
262 let j_a = if tokens_a.is_empty() {
263 0.0_f32
264 } else {
265 intersection.len() as f32 / tokens_a.len() as f32
266 };
267
268 j_full.max(j_e).max(j_a)
269}
270
271#[must_use]
275pub fn score_answers(questions: &[ProbeQuestion], answers: &[String]) -> (Vec<f32>, f32) {
276 if questions.is_empty() {
277 return (vec![], 0.0);
278 }
279 let scores: Vec<f32> = questions
280 .iter()
281 .zip(answers.iter().chain(std::iter::repeat(&String::new())))
282 .map(|(q, a)| score_pair(&q.expected_answer, a))
283 .collect();
284 #[allow(clippy::cast_precision_loss)]
285 let avg = if scores.is_empty() {
286 0.0
287 } else {
288 scores.iter().sum::<f32>() / scores.len() as f32
289 };
290 (scores, avg)
291}
292
293fn truncate_tool_bodies(messages: &[Message]) -> Vec<Message> {
297 messages
298 .iter()
299 .map(|m| {
300 let mut msg = m.clone();
301 for part in &mut msg.parts {
302 if let MessagePart::ToolOutput { body, .. } = part {
303 if body.len() <= 500 {
304 continue;
305 }
306 body.truncate(500);
307 body.push('\u{2026}');
308 }
309 }
310 msg.rebuild_content();
311 msg
312 })
313 .collect()
314}
315
316#[cfg_attr(
325 feature = "profiling",
326 tracing::instrument(name = "memory.compaction_probe", skip_all)
327)]
328pub async fn generate_probe_questions(
329 provider: &AnyProvider,
330 messages: &[Message],
331 max_questions: usize,
332) -> Result<Vec<ProbeQuestion>, MemoryError> {
333 let truncated = truncate_tool_bodies(messages);
334
335 let mut history = String::new();
336 for msg in &truncated {
337 let role = match msg.role {
338 Role::Assistant => "assistant",
339 Role::System => "system",
340 Role::User | _ => "user",
341 };
342 history.push_str(role);
343 history.push_str(": ");
344 history.push_str(&msg.content);
345 history.push('\n');
346 }
347
348 let prompt = format!(
349 "Given the following conversation excerpt, generate {max_questions} factual questions \
350 that test whether a summary preserves the most important concrete details.\n\
351 \n\
352 You MUST generate at least one question per category when max_questions >= 4. \
353 If the conversation lacks information for a category, generate a question noting that absence.\n\
354 \n\
355 Categories:\n\
356 - recall: Specific facts that survived (file paths, function names, values). \
357 Example: \"What file was modified?\"\n\
358 - artifact: Which files/tools/URLs the agent used. \
359 Example: \"Which tool was executed?\"\n\
360 - continuation: Next steps, blockers, open questions. \
361 Example: \"What is the next step?\"\n\
362 - decision: Past reasoning traces (why X over Y, trade-offs). \
363 Example: \"Why was X chosen over Y?\"\n\
364 \n\
365 Do NOT generate questions about:\n\
366 - Raw tool output content (compiler warnings, test output line numbers)\n\
367 - Intermediate debugging steps that were superseded\n\
368 - Opinions or reasoning that cannot be verified\n\
369 \n\
370 Each question must have a single unambiguous expected answer extractable from the text.\n\
371 \n\
372 Conversation:\n{history}\n\
373 \n\
374 Respond in JSON with schema: {{\"questions\": [{{\"question\": \"...\", \
375 \"expected_answer\": \"...\", \"category\": \"recall|artifact|continuation|decision\"}}]}}"
376 );
377
378 let msgs = [Message {
379 role: Role::User,
380 content: prompt,
381 parts: vec![],
382 metadata: MessageMetadata::default(),
383 }];
384
385 let mut output: ProbeQuestionsOutput = provider
386 .chat_typed_erased::<ProbeQuestionsOutput>(&msgs)
387 .await
388 .map_err(MemoryError::Llm)?;
389
390 output.questions.truncate(max_questions);
392
393 Ok(output.questions)
394}
395
396pub async fn answer_probe_questions(
402 provider: &AnyProvider,
403 summary: &str,
404 questions: &[ProbeQuestion],
405) -> Result<Vec<String>, MemoryError> {
406 let mut numbered = String::new();
407 for (i, q) in questions.iter().enumerate() {
408 use std::fmt::Write as _;
409 let _ = writeln!(numbered, "{}. {}", i + 1, q.question);
410 }
411
412 let prompt = format!(
413 "Given the following summary of a conversation, answer each question using ONLY \
414 information present in the summary. If the answer is not in the summary, respond \
415 with \"UNKNOWN\".\n\
416 \n\
417 Summary:\n{summary}\n\
418 \n\
419 Questions:\n{numbered}\n\
420 \n\
421 Respond in JSON with schema: {{\"answers\": [\"answer1\", \"answer2\", ...]}}"
422 );
423
424 let msgs = [Message {
425 role: Role::User,
426 content: prompt,
427 parts: vec![],
428 metadata: MessageMetadata::default(),
429 }];
430
431 let output: ProbeAnswersOutput = provider
432 .chat_typed_erased::<ProbeAnswersOutput>(&msgs)
433 .await
434 .map_err(MemoryError::Llm)?;
435
436 Ok(output.answers)
437}
438
439pub async fn validate_compaction(
453 provider: AnyProvider,
454 messages: Vec<Message>,
455 summary: String,
456 config: &CompactionProbeConfig,
457) -> Result<Option<CompactionProbeResult>, MemoryError> {
458 if !config.enabled {
459 return Ok(None);
460 }
461
462 let timeout = std::time::Duration::from_secs(config.timeout_secs);
463 let start = Instant::now();
464
465 let result = tokio::time::timeout(timeout, async {
466 run_probe(provider, messages, summary, config).await
467 })
468 .await;
469
470 match result {
471 Ok(inner) => inner,
472 Err(_elapsed) => {
473 tracing::warn!(
474 timeout_secs = config.timeout_secs,
475 "compaction probe timed out — proceeding with compaction"
476 );
477 Ok(None)
478 }
479 }
480 .map(|opt| {
481 opt.map(|mut r| {
482 r.duration_ms = u64::try_from(start.elapsed().as_millis()).unwrap_or(u64::MAX);
483 r
484 })
485 })
486}
487
488async fn run_probe(
489 provider: AnyProvider,
490 messages: Vec<Message>,
491 summary: String,
492 config: &CompactionProbeConfig,
493) -> Result<Option<CompactionProbeResult>, MemoryError> {
494 if summary.len() < 10 {
495 tracing::warn!(
496 len = summary.len(),
497 "compaction probe: summary too short — skipping probe"
498 );
499 return Ok(None);
500 }
501
502 let questions = generate_probe_questions(&provider, &messages, config.max_questions).await?;
503
504 if questions.len() < 2 {
505 tracing::debug!(
506 count = questions.len(),
507 "compaction probe: fewer than 2 questions generated — skipping probe"
508 );
509 return Ok(None);
510 }
511
512 if config.max_questions >= 4 {
514 use std::collections::HashSet;
515 let covered: HashSet<_> = questions.iter().map(|q| q.category).collect();
516 for cat in [
517 ProbeCategory::Recall,
518 ProbeCategory::Artifact,
519 ProbeCategory::Continuation,
520 ProbeCategory::Decision,
521 ] {
522 if !covered.contains(&cat) {
523 tracing::warn!(
524 category = ?cat,
525 "compaction probe: LLM did not generate questions for category"
526 );
527 }
528 }
529 }
530
531 let answers = answer_probe_questions(&provider, &summary, &questions).await?;
532
533 let (per_question_scores, _simple_avg) = score_answers(&questions, &answers);
534
535 let (category_scores, score) = compute_category_scores(
536 &questions,
537 &per_question_scores,
538 config.category_weights.as_ref(),
539 );
540
541 let verdict = if score >= config.threshold {
542 ProbeVerdict::Pass
543 } else if score >= config.hard_fail_threshold {
544 ProbeVerdict::SoftFail
545 } else {
546 ProbeVerdict::HardFail
547 };
548
549 let model = provider.name().to_owned();
550
551 Ok(Some(CompactionProbeResult {
552 score,
553 category_scores,
554 questions,
555 answers,
556 per_question_scores,
557 verdict,
558 threshold: config.threshold,
559 hard_fail_threshold: config.hard_fail_threshold,
560 model,
561 duration_ms: 0, }))
563}
564
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a probe question with an explicit category.
    fn probe(question: &str, expected_answer: &str, category: ProbeCategory) -> ProbeQuestion {
        ProbeQuestion {
            question: question.to_owned(),
            expected_answer: expected_answer.to_owned(),
            category,
        }
    }

    /// Builds a probe question in the `Recall` category, which is also the
    /// category `ProbeQuestion::default()` uses.
    fn recall(question: &str, expected_answer: &str) -> ProbeQuestion {
        probe(question, expected_answer, ProbeCategory::Recall)
    }

    /// Scores one question/answer pair through the public `score_answers`.
    fn score_single(question: &str, expected_answer: &str, answer: &str) -> (Vec<f32>, f32) {
        score_answers(&[recall(question, expected_answer)], &[answer.to_owned()])
    }

    /// Test-local copy of the verdict ladder used by `run_probe`.
    fn classify(score: f32, config: &CompactionProbeConfig) -> ProbeVerdict {
        if score >= config.threshold {
            ProbeVerdict::Pass
        } else if score >= config.hard_fail_threshold {
            ProbeVerdict::SoftFail
        } else {
            ProbeVerdict::HardFail
        }
    }

    // --- answer scoring ---------------------------------------------------

    #[test]
    fn score_perfect_match() {
        let (scores, avg) = score_single("What crate is used?", "thiserror", "thiserror");
        assert_eq!(scores.len(), 1);
        assert!((avg - 1.0).abs() < 0.01, "expected ~1.0, got {avg}");
    }

    #[test]
    fn score_complete_mismatch() {
        let (scores, avg) = score_single(
            "What file was modified?",
            "src/auth.rs",
            "definitely not in the summary",
        );
        assert_eq!(scores.len(), 1);
        assert!(avg < 0.5, "expected low score, got {avg}");
    }

    #[test]
    fn score_refusal_is_zero() {
        for refusal in [
            "UNKNOWN",
            "not mentioned",
            "N/A",
            "cannot determine",
            "No information",
        ] {
            let (_, avg) = score_single(
                "What was the decision?",
                "Use thiserror for typed errors",
                refusal,
            );
            assert!(avg < 0.01, "expected 0 for refusal '{refusal}', got {avg}");
        }
    }

    #[test]
    fn score_paraphrased_answer_above_half() {
        let (_, avg) = score_single(
            "What error handling crate was chosen?",
            "Use thiserror for typed errors in library crates",
            "thiserror was chosen for error types in library crates",
        );
        assert!(avg > 0.5, "expected >0.5 for paraphrase, got {avg}");
    }

    #[test]
    fn score_empty_strings() {
        let (scores, avg) = score_single("What?", "", "");
        assert_eq!(scores.len(), 1);
        assert!(
            (avg - 1.0).abs() < 0.01,
            "expected 1.0 for empty vs empty, got {avg}"
        );
    }

    #[test]
    fn score_empty_questions_list() {
        let (scores, avg) = score_answers(&[], &[]);
        assert!(scores.is_empty());
        assert!((avg - 0.0).abs() < 0.01);
    }

    #[test]
    fn score_file_path_exact() {
        let (_, avg) = score_single(
            "Which file was modified?",
            "crates/zeph-memory/src/compaction_probe.rs",
            "The file crates/zeph-memory/src/compaction_probe.rs was modified.",
        );
        assert!(
            avg > 0.8,
            "expected high score for file path match, got {avg}"
        );
    }

    #[test]
    fn score_unicode_input() {
        // Only checks that non-ASCII input is processed without panicking.
        let (scores, _) =
            score_single("Что было изменено?", "файл config.toml", "config.toml был изменён");
        assert_eq!(scores.len(), 1);
    }

    #[test]
    fn score_fewer_answers_than_questions() {
        let questions = vec![
            recall("What crate?", "thiserror"),
            recall("What file?", "src/lib.rs"),
            recall("What decision?", "use async traits"),
        ];
        let answers = vec!["thiserror".to_owned()];

        let (scores, avg) = score_answers(&questions, &answers);
        assert_eq!(scores.len(), 3);
        assert!(
            (scores[0] - 1.0).abs() < 0.01,
            "first score should be ~1.0, got {}",
            scores[0]
        );
        assert!(
            scores[1] < 0.5,
            "second score should be low for missing answer, got {}",
            scores[1]
        );
        assert!(
            scores[2] < 0.5,
            "third score should be low for missing answer, got {}",
            scores[2]
        );
        assert!(
            avg < 0.5,
            "average should be below 0.5 with 2 missing answers, got {avg}"
        );
    }

    // --- verdicts -----------------------------------------------------------

    #[test]
    fn verdict_thresholds() {
        let config = CompactionProbeConfig::default();
        for (score, expected) in [
            (0.7_f32, ProbeVerdict::Pass),
            (0.5_f32, ProbeVerdict::SoftFail),
            (0.2_f32, ProbeVerdict::HardFail),
        ] {
            assert_eq!(classify(score, &config), expected);
        }
    }

    #[test]
    fn verdict_boundary_at_threshold() {
        let config = CompactionProbeConfig::default();
        for (score, expected) in [
            (config.threshold, ProbeVerdict::Pass),
            (config.threshold - f32::EPSILON, ProbeVerdict::SoftFail),
            (config.hard_fail_threshold, ProbeVerdict::SoftFail),
            (
                config.hard_fail_threshold - f32::EPSILON,
                ProbeVerdict::HardFail,
            ),
        ] {
            assert_eq!(classify(score, &config), expected);
        }
    }

    // --- configuration --------------------------------------------------------

    #[test]
    fn config_defaults() {
        let c = CompactionProbeConfig::default();
        assert!(!c.enabled);
        assert!(c.probe_provider.is_none());
        assert!((c.threshold - 0.6).abs() < 0.001);
        assert!((c.hard_fail_threshold - 0.35).abs() < 0.001);
        assert_eq!(c.max_questions, 5);
        assert_eq!(c.timeout_secs, 15);
        assert!(c.category_weights.is_none());
    }

    #[test]
    fn config_serde_round_trip() {
        let original = CompactionProbeConfig {
            enabled: true,
            probe_provider: Some(zeph_config::ProviderName::new("fast")),
            threshold: 0.65,
            hard_fail_threshold: 0.4,
            max_questions: 5,
            timeout_secs: 20,
            category_weights: None,
        };
        let json = serde_json::to_string(&original).expect("serialize");
        let restored: CompactionProbeConfig = serde_json::from_str(&json).expect("deserialize");

        assert!(restored.enabled);
        let provider_name = restored
            .probe_provider
            .as_ref()
            .map(zeph_config::ProviderName::as_str);
        assert_eq!(provider_name, Some("fast"));
        assert!((restored.threshold - 0.65).abs() < 0.001);
    }

    #[test]
    fn config_partial_json_uses_defaults() {
        let c: CompactionProbeConfig =
            serde_json::from_str(r#"{"enabled": true}"#).expect("deserialize partial json");
        assert!(c.enabled);
        assert!(c.probe_provider.is_none());
        assert!((c.threshold - 0.6).abs() < 0.001);
        assert!((c.hard_fail_threshold - 0.35).abs() < 0.001);
        assert_eq!(c.max_questions, 5);
        assert_eq!(c.timeout_secs, 15);
    }

    #[test]
    fn config_empty_json_uses_all_defaults() {
        let c: CompactionProbeConfig = serde_json::from_str("{}").expect("deserialize empty json");
        assert!(!c.enabled);
        assert!(c.probe_provider.is_none());
    }

    #[test]
    fn category_weights_toml_round_trip() {
        let toml_str = r#"
enabled = true
probe_provider = "fast"
threshold = 0.6
hard_fail_threshold = 0.35
max_questions = 5
timeout_secs = 15
[category_weights]
recall = 1.5
artifact = 1.0
continuation = 1.0
decision = 0.8
"#;
        let c: CompactionProbeConfig = toml::from_str(toml_str).expect("deserialize toml");
        let weights = c.category_weights.as_ref().unwrap();
        assert!((weights[&ProbeCategory::Recall] - 1.5).abs() < 0.001);
        assert!((weights[&ProbeCategory::Decision] - 0.8).abs() < 0.001);
    }

    // --- serde shapes -----------------------------------------------------------

    #[test]
    fn probe_result_serde_round_trip() {
        let result = CompactionProbeResult {
            score: 0.75,
            category_scores: vec![CategoryScore {
                category: ProbeCategory::Recall,
                score: 0.75,
                probes_run: 1,
            }],
            questions: vec![recall("What?", "thiserror")],
            answers: vec!["thiserror".to_owned()],
            per_question_scores: vec![1.0],
            verdict: ProbeVerdict::Pass,
            threshold: 0.6,
            hard_fail_threshold: 0.35,
            model: "haiku".to_owned(),
            duration_ms: 1234,
        };
        let json = serde_json::to_string(&result).expect("serialize");
        let restored: CompactionProbeResult = serde_json::from_str(&json).expect("deserialize");
        assert!((restored.score - 0.75).abs() < 0.001);
        assert_eq!(restored.verdict, ProbeVerdict::Pass);
        assert_eq!(restored.category_scores.len(), 1);
    }

    #[test]
    fn probe_result_backward_compat_no_category_scores() {
        // Payload written before `category_scores` existed.
        let json = r#"{"score":0.75,"questions":[],"answers":[],"per_question_scores":[],"verdict":"Pass","threshold":0.6,"hard_fail_threshold":0.35,"model":"haiku","duration_ms":0}"#;
        let restored: CompactionProbeResult = serde_json::from_str(json).expect("deserialize");
        assert!(restored.category_scores.is_empty());
    }

    #[test]
    fn probe_category_serde_lowercase() {
        for (category, expected_json) in [
            (ProbeCategory::Recall, r#""recall""#),
            (ProbeCategory::Artifact, r#""artifact""#),
            (ProbeCategory::Continuation, r#""continuation""#),
            (ProbeCategory::Decision, r#""decision""#),
        ] {
            assert_eq!(serde_json::to_string(&category).unwrap(), expected_json);
        }
        let cat: ProbeCategory = serde_json::from_str(r#""recall""#).unwrap();
        assert_eq!(cat, ProbeCategory::Recall);
    }

    #[test]
    fn probe_question_serde_default_category() {
        let json = r#"{"question":"What file?","expected_answer":"src/lib.rs"}"#;
        let q: ProbeQuestion = serde_json::from_str(json).expect("deserialize");
        assert_eq!(q.category, ProbeCategory::Recall);
        assert_eq!(q.question, "What file?");
        assert_eq!(q.expected_answer, "src/lib.rs");
    }

    #[test]
    fn probe_question_serde_all_categories_round_trip() {
        for cat in [
            ProbeCategory::Recall,
            ProbeCategory::Artifact,
            ProbeCategory::Continuation,
            ProbeCategory::Decision,
        ] {
            let original = probe("test?", "answer", cat);
            let json = serde_json::to_string(&original).expect("serialize");
            let restored: ProbeQuestion = serde_json::from_str(&json).expect("deserialize");
            assert_eq!(restored.category, cat);
        }
    }

    // --- category aggregation -----------------------------------------------------

    #[test]
    fn category_scores_equal_weights() {
        let questions = vec![
            probe("Q1", "A1", ProbeCategory::Recall),
            probe("Q2", "A2", ProbeCategory::Artifact),
        ];
        let (cats, overall) = compute_category_scores(&questions, &[1.0_f32, 0.0_f32], None);
        assert_eq!(cats.len(), 2);
        assert!((overall - 0.5).abs() < 0.001, "overall={overall}");
    }

    #[test]
    fn category_scores_missing_category_excluded() {
        let questions = vec![
            probe("Q1", "A1", ProbeCategory::Recall),
            probe("Q2", "A2", ProbeCategory::Decision),
        ];
        let (cats, _overall) = compute_category_scores(&questions, &[1.0_f32, 0.6_f32], None);
        assert_eq!(cats.len(), 2, "only categories with questions present");
        assert!(cats.iter().all(|c| c.category != ProbeCategory::Artifact));
        assert!(cats
            .iter()
            .all(|c| c.category != ProbeCategory::Continuation));
    }

    #[test]
    fn category_scores_custom_weights() {
        let questions = vec![
            probe("Q1", "A1", ProbeCategory::Recall),
            probe("Q2", "A2", ProbeCategory::Decision),
        ];
        let weights: HashMap<ProbeCategory, f32> = [
            (ProbeCategory::Recall, 2.0_f32),
            (ProbeCategory::Decision, 1.0_f32),
        ]
        .into_iter()
        .collect();
        let (_, overall) =
            compute_category_scores(&questions, &[1.0_f32, 0.0_f32], Some(&weights));
        assert!(
            (overall - 2.0 / 3.0).abs() < 0.001,
            "expected ~0.667, got {overall}"
        );
    }

    #[test]
    fn category_scores_all_zero_weights_fallback() {
        let questions = vec![
            probe("Q1", "A1", ProbeCategory::Recall),
            probe("Q2", "A2", ProbeCategory::Artifact),
        ];
        let weights: HashMap<ProbeCategory, f32> = [
            (ProbeCategory::Recall, 0.0_f32),
            (ProbeCategory::Artifact, 0.0_f32),
        ]
        .into_iter()
        .collect();
        let (_, overall) =
            compute_category_scores(&questions, &[1.0_f32, 0.0_f32], Some(&weights));
        assert!((overall - 0.5).abs() < 0.001, "overall={overall}");
    }

    #[test]
    fn category_scores_empty_questions() {
        let (cats, overall) = compute_category_scores(&[], &[], None);
        assert!(cats.is_empty());
        assert!((overall - 0.0).abs() < 0.001);
    }

    #[test]
    fn category_scores_multi_probe_single_category_averages() {
        let questions = vec![
            probe("Q1", "A1", ProbeCategory::Recall),
            probe("Q2", "A2", ProbeCategory::Recall),
            probe("Q3", "A3", ProbeCategory::Recall),
        ];
        let (cats, overall) =
            compute_category_scores(&questions, &[1.0_f32, 0.0_f32, 0.5_f32], None);
        assert_eq!(cats.len(), 1, "only one category present");
        assert_eq!(cats[0].category, ProbeCategory::Recall);
        assert_eq!(cats[0].probes_run, 3);
        assert!(
            (cats[0].score - 0.5).abs() < 0.001,
            "cat score={}",
            cats[0].score
        );
        assert!((overall - 0.5).abs() < 0.001, "overall={overall}");
    }
}