Skip to main content

zeph_memory/
compaction_probe.rs

1// SPDX-FileCopyrightText: 2026 Andrei G <bug-ops>
2// SPDX-License-Identifier: MIT OR Apache-2.0
3
4//! Compaction probe: validates summary quality before committing it to the context.
5//!
6//! Generates factual questions from the messages being compacted, then answers them
7//! using only the summary text, and scores the answers against expected values.
8//! Returns a [`CompactionProbeResult`] that the caller uses to decide whether to
9//! commit or reject the summary.
10
11use std::collections::HashMap;
12use std::time::Instant;
13
14use schemars::JsonSchema;
15use serde::{Deserialize, Serialize};
16use zeph_llm::any::AnyProvider;
17use zeph_llm::provider::{LlmProvider as _, Message, MessageMetadata, MessagePart, Role};
18
19use crate::error::MemoryError;
20
21// --- Data structures ---
22
23use zeph_config::memory::{CompactionProbeConfig, ProbeCategory};
24
/// Per-category scoring breakdown from a compaction probe run.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CategoryScore {
    /// Functional category this breakdown covers.
    pub category: ProbeCategory,
    /// Average score of all questions in this category, in [0.0, 1.0].
    pub score: f32,
    /// Number of questions generated for this category.
    pub probes_run: u32,
}
34
/// A single factual question with the expected answer.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct ProbeQuestion {
    /// Factual question about the compacted messages.
    pub question: String,
    /// Expected correct answer extractable from the original messages.
    pub expected_answer: String,
    /// Functional category of this question.
    /// Falls back to `Recall` when the field is absent from the LLM's JSON.
    #[serde(default = "default_probe_category")]
    pub category: ProbeCategory,
}
46
/// Serde default for [`ProbeQuestion::category`]: untagged questions count as recall.
fn default_probe_category() -> ProbeCategory {
    ProbeCategory::Recall
}
50
51impl Default for ProbeQuestion {
52    fn default() -> Self {
53        Self {
54            question: String::new(),
55            expected_answer: String::new(),
56            category: ProbeCategory::Recall,
57        }
58    }
59}
60
/// Three-tier verdict for compaction probe quality.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ProbeVerdict {
    /// Score >= `threshold`: summary preserves enough context. Proceed.
    Pass,
    /// Score in [`hard_fail_threshold`, `threshold`): summary is borderline.
    /// Proceed with compaction but log a warning.
    SoftFail,
    /// Score < `hard_fail_threshold`: summary lost critical facts. Block compaction.
    HardFail,
    /// Transport/timeout failure — no quality score produced.
    // NOTE(review): this variant is never constructed in this file — presumably
    // set by callers on transport errors; verify before relying on it.
    Error,
}
74
/// Full result of a compaction probe run.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CompactionProbeResult {
    /// Overall score in [0.0, 1.0].
    pub score: f32,
    /// Per-category breakdown. Categories with 0 questions are omitted.
    #[serde(default)]
    pub category_scores: Vec<CategoryScore>,
    /// Per-question breakdown.
    pub questions: Vec<ProbeQuestion>,
    /// LLM answers to the questions (positionally aligned with `questions`).
    pub answers: Vec<String>,
    /// Per-question similarity scores (positionally aligned with `questions`).
    pub per_question_scores: Vec<f32>,
    /// Verdict derived from `score` against the two thresholds below.
    pub verdict: ProbeVerdict,
    /// Pass threshold used for this run.
    pub threshold: f32,
    /// Hard-fail threshold used for this run.
    pub hard_fail_threshold: f32,
    /// Model name used for the probe.
    pub model: String,
    /// Wall-clock duration in milliseconds.
    pub duration_ms: u64,
}
99
100// --- Structured LLM output types ---
101
/// Structured-output envelope for the question-generation LLM call.
#[derive(Debug, Deserialize, JsonSchema)]
struct ProbeQuestionsOutput {
    questions: Vec<ProbeQuestion>,
}
106
107// --- Category scoring ---
108
109/// Group per-question scores by category and compute per-category averages.
110///
111/// Categories with no questions are excluded from the returned list.
112#[must_use]
113fn compute_category_scores(
114    questions: &[ProbeQuestion],
115    per_question_scores: &[f32],
116    category_weights: Option<&HashMap<ProbeCategory, f32>>,
117) -> (Vec<CategoryScore>, f32) {
118    // Group scores by category.
119    let mut by_cat: HashMap<ProbeCategory, Vec<f32>> = HashMap::new();
120    for (q, &s) in questions.iter().zip(per_question_scores.iter()) {
121        by_cat.entry(q.category).or_default().push(s);
122    }
123
124    #[allow(clippy::cast_precision_loss)]
125    let category_scores: Vec<CategoryScore> = by_cat
126        .into_iter()
127        .map(|(category, scores)| {
128            let avg = scores.iter().sum::<f32>() / scores.len() as f32;
129            CategoryScore {
130                category,
131                score: avg,
132                #[allow(clippy::cast_possible_truncation)]
133                probes_run: scores.len() as u32,
134            }
135        })
136        .collect();
137
138    if category_scores.is_empty() {
139        return (category_scores, 0.0);
140    }
141
142    // Compute weighted overall score. Default: equal weights.
143    let mut weighted_sum = 0.0_f32;
144    let mut weight_total = 0.0_f32;
145    for cs in &category_scores {
146        let raw_w = category_weights
147            .and_then(|m| m.get(&cs.category).copied())
148            .unwrap_or(1.0);
149        if raw_w < 0.0 {
150            tracing::warn!(
151                category = ?cs.category,
152                weight = raw_w,
153                "category_weights contains a negative value — treating as 0.0 (category excluded from scoring)"
154            );
155        }
156        let w = raw_w.max(0.0);
157        weighted_sum += cs.score * w;
158        weight_total += w;
159    }
160
161    let overall = if weight_total > 0.0 {
162        weighted_sum / weight_total
163    } else {
164        // All weights are 0 — fall back to equal weighting.
165        #[allow(clippy::cast_precision_loss)]
166        let n = category_scores.len() as f32;
167        category_scores.iter().map(|cs| cs.score).sum::<f32>() / n
168    };
169
170    (category_scores, overall)
171}
172
/// Structured-output envelope for the question-answering LLM call.
#[derive(Debug, Deserialize, JsonSchema)]
struct ProbeAnswersOutput {
    answers: Vec<String>,
}
177
178// --- Scoring ---
179
/// Refusal indicators: if the actual answer contains any of these, score it 0.0.
const REFUSAL_PATTERNS: &[&str] = &[
    "unknown",
    "not mentioned",
    "not found",
    "n/a",
    "cannot determine",
    "no information",
    "not provided",
    "not specified",
    "not stated",
    "not available",
];

/// Case-insensitive check: does `text` contain any refusal marker?
fn is_refusal(text: &str) -> bool {
    let lowered = text.to_lowercase();
    for pattern in REFUSAL_PATTERNS {
        if lowered.contains(pattern) {
            return true;
        }
    }
    false
}
198
/// Normalize a string: lowercase, split on non-alphanumeric chars, keep tokens >= 3 chars.
fn normalize_tokens(text: &str) -> Vec<String> {
    let lowered = text.to_lowercase();
    lowered
        .split(|c: char| !c.is_alphanumeric())
        .filter(|token| token.len() >= 3)
        .map(|token| token.to_owned())
        .collect()
}
207
/// Jaccard similarity of two token lists; two empty inputs count as identical.
fn jaccard(a: &[String], b: &[String]) -> f32 {
    if a.is_empty() && b.is_empty() {
        return 1.0;
    }
    let left: std::collections::HashSet<&str> = a.iter().map(String::as_str).collect();
    let right: std::collections::HashSet<&str> = b.iter().map(String::as_str).collect();
    let shared = left.intersection(&right).count();
    // |A ∪ B| = |A| + |B| - |A ∩ B|
    let combined = left.len() + right.len() - shared;
    if combined == 0 {
        return 0.0;
    }
    #[allow(clippy::cast_precision_loss)]
    {
        shared as f32 / combined as f32
    }
}
224
225/// Score a single (expected, actual) answer pair using token-set-ratio.
226fn score_pair(expected: &str, actual: &str) -> f32 {
227    if is_refusal(actual) {
228        return 0.0;
229    }
230
231    let tokens_e = normalize_tokens(expected);
232    let tokens_a = normalize_tokens(actual);
233
234    // Substring boost: if all expected tokens appear in actual, it's an exact match.
235    if !tokens_e.is_empty() {
236        let set_e: std::collections::HashSet<&str> = tokens_e.iter().map(String::as_str).collect();
237        let set_a: std::collections::HashSet<&str> = tokens_a.iter().map(String::as_str).collect();
238        if set_e.is_subset(&set_a) {
239            return 1.0;
240        }
241    }
242
243    // Token-set-ratio: max of three Jaccard variants.
244    let j_full = jaccard(&tokens_e, &tokens_a);
245
246    // Intersection with each set individually (handles subset relationships).
247    let set_e: std::collections::HashSet<&str> = tokens_e.iter().map(String::as_str).collect();
248    let set_a: std::collections::HashSet<&str> = tokens_a.iter().map(String::as_str).collect();
249    let intersection: Vec<String> = set_e
250        .intersection(&set_a)
251        .map(|s| (*s).to_owned())
252        .collect();
253
254    #[allow(clippy::cast_precision_loss)]
255    let j_e = if tokens_e.is_empty() {
256        0.0_f32
257    } else {
258        intersection.len() as f32 / tokens_e.len() as f32
259    };
260    #[allow(clippy::cast_precision_loss)]
261    let j_a = if tokens_a.is_empty() {
262        0.0_f32
263    } else {
264        intersection.len() as f32 / tokens_a.len() as f32
265    };
266
267    j_full.max(j_e).max(j_a)
268}
269
270/// Score answers against expected values using token-set-ratio similarity.
271///
272/// Returns `(per_question_scores, overall_average)`.
273#[must_use]
274pub fn score_answers(questions: &[ProbeQuestion], answers: &[String]) -> (Vec<f32>, f32) {
275    if questions.is_empty() {
276        return (vec![], 0.0);
277    }
278    let scores: Vec<f32> = questions
279        .iter()
280        .zip(answers.iter().chain(std::iter::repeat(&String::new())))
281        .map(|(q, a)| score_pair(&q.expected_answer, a))
282        .collect();
283    #[allow(clippy::cast_precision_loss)]
284    let avg = if scores.is_empty() {
285        0.0
286    } else {
287        scores.iter().sum::<f32>() / scores.len() as f32
288    };
289    (scores, avg)
290}
291
292// --- LLM calls ---
293
294/// Truncate tool-result bodies to 500 chars to avoid flooding the probe with raw output.
295fn truncate_tool_bodies(messages: &[Message]) -> Vec<Message> {
296    messages
297        .iter()
298        .map(|m| {
299            let mut msg = m.clone();
300            for part in &mut msg.parts {
301                if let MessagePart::ToolOutput { body, .. } = part {
302                    if body.len() <= 500 {
303                        continue;
304                    }
305                    body.truncate(500);
306                    body.push('\u{2026}');
307                }
308            }
309            msg.rebuild_content();
310            msg
311        })
312        .collect()
313}
314
/// Generate factual probe questions from the messages being compacted.
///
/// Uses a single LLM call with structured output. Tool-result bodies are
/// truncated (500-byte budget) to focus on decisions and outcomes rather than
/// raw tool output. The returned list is capped at `max_questions` even if the
/// LLM returns more.
///
/// # Errors
///
/// Returns `MemoryError::Llm` if the LLM call fails.
#[cfg_attr(
    feature = "profiling",
    tracing::instrument(name = "memory.compaction_probe", skip_all)
)]
pub async fn generate_probe_questions(
    provider: &AnyProvider,
    messages: &[Message],
    max_questions: usize,
) -> Result<Vec<ProbeQuestion>, MemoryError> {
    let truncated = truncate_tool_bodies(messages);

    // Flatten the conversation into "role: content" lines for the prompt.
    let mut history = String::new();
    for msg in &truncated {
        let role = match msg.role {
            Role::User => "user",
            Role::Assistant => "assistant",
            Role::System => "system",
        };
        history.push_str(role);
        history.push_str(": ");
        history.push_str(&msg.content);
        history.push('\n');
    }

    let prompt = format!(
        "Given the following conversation excerpt, generate {max_questions} factual questions \
         that test whether a summary preserves the most important concrete details.\n\
         \n\
         You MUST generate at least one question per category when max_questions >= 4. \
         If the conversation lacks information for a category, generate a question noting that absence.\n\
         \n\
         Categories:\n\
         - recall: Specific facts that survived (file paths, function names, values). \
           Example: \"What file was modified?\"\n\
         - artifact: Which files/tools/URLs the agent used. \
           Example: \"Which tool was executed?\"\n\
         - continuation: Next steps, blockers, open questions. \
           Example: \"What is the next step?\"\n\
         - decision: Past reasoning traces (why X over Y, trade-offs). \
           Example: \"Why was X chosen over Y?\"\n\
         \n\
         Do NOT generate questions about:\n\
         - Raw tool output content (compiler warnings, test output line numbers)\n\
         - Intermediate debugging steps that were superseded\n\
         - Opinions or reasoning that cannot be verified\n\
         \n\
         Each question must have a single unambiguous expected answer extractable from the text.\n\
         \n\
         Conversation:\n{history}\n\
         \n\
         Respond in JSON with schema: {{\"questions\": [{{\"question\": \"...\", \
         \"expected_answer\": \"...\", \"category\": \"recall|artifact|continuation|decision\"}}]}}"
    );

    let msgs = [Message {
        role: Role::User,
        content: prompt,
        parts: vec![],
        metadata: MessageMetadata::default(),
    }];

    let mut output: ProbeQuestionsOutput = provider
        .chat_typed_erased::<ProbeQuestionsOutput>(&msgs)
        .await
        .map_err(MemoryError::Llm)?;

    // Cap the list to max_questions: a misbehaving LLM could return more.
    output.questions.truncate(max_questions);

    Ok(output.questions)
}
394
/// Answer probe questions using only the compaction summary as context.
///
/// The answers are expected positionally aligned with `questions`, but the
/// LLM may return a shorter list; `score_answers` pads missing entries with
/// empty strings, so no length check is performed here.
///
/// # Errors
///
/// Returns `MemoryError::Llm` if the LLM call fails.
pub async fn answer_probe_questions(
    provider: &AnyProvider,
    summary: &str,
    questions: &[ProbeQuestion],
) -> Result<Vec<String>, MemoryError> {
    // Render a numbered question list ("1. …") for the prompt.
    let mut numbered = String::new();
    for (i, q) in questions.iter().enumerate() {
        use std::fmt::Write as _;
        let _ = writeln!(numbered, "{}. {}", i + 1, q.question);
    }

    let prompt = format!(
        "Given the following summary of a conversation, answer each question using ONLY \
         information present in the summary. If the answer is not in the summary, respond \
         with \"UNKNOWN\".\n\
         \n\
         Summary:\n{summary}\n\
         \n\
         Questions:\n{numbered}\n\
         \n\
         Respond in JSON with schema: {{\"answers\": [\"answer1\", \"answer2\", ...]}}"
    );

    let msgs = [Message {
        role: Role::User,
        content: prompt,
        parts: vec![],
        metadata: MessageMetadata::default(),
    }];

    let output: ProbeAnswersOutput = provider
        .chat_typed_erased::<ProbeAnswersOutput>(&msgs)
        .await
        .map_err(MemoryError::Llm)?;

    Ok(output.answers)
}
437
438/// Run the compaction probe: generate questions, answer them from the summary, score results.
439///
440/// Returns `Ok(None)` when:
441/// - Probe is disabled (`config.enabled = false`)
442/// - The probe times out
443/// - Fewer than 2 questions are generated (insufficient statistical power)
444///
445/// The caller treats `None` as "no opinion" and proceeds with compaction.
446///
447/// # Errors
448///
449/// Returns `MemoryError` if an LLM call fails. Callers should treat this as non-fatal
450/// and proceed with compaction.
451pub async fn validate_compaction(
452    provider: AnyProvider,
453    messages: Vec<Message>,
454    summary: String,
455    config: &CompactionProbeConfig,
456) -> Result<Option<CompactionProbeResult>, MemoryError> {
457    if !config.enabled {
458        return Ok(None);
459    }
460
461    let timeout = std::time::Duration::from_secs(config.timeout_secs);
462    let start = Instant::now();
463
464    let result = tokio::time::timeout(timeout, async {
465        run_probe(provider, messages, summary, config).await
466    })
467    .await;
468
469    match result {
470        Ok(inner) => inner,
471        Err(_elapsed) => {
472            tracing::warn!(
473                timeout_secs = config.timeout_secs,
474                "compaction probe timed out — proceeding with compaction"
475            );
476            Ok(None)
477        }
478    }
479    .map(|opt| {
480        opt.map(|mut r| {
481            r.duration_ms = u64::try_from(start.elapsed().as_millis()).unwrap_or(u64::MAX);
482            r
483        })
484    })
485}
486
487async fn run_probe(
488    provider: AnyProvider,
489    messages: Vec<Message>,
490    summary: String,
491    config: &CompactionProbeConfig,
492) -> Result<Option<CompactionProbeResult>, MemoryError> {
493    if summary.len() < 10 {
494        tracing::warn!(
495            len = summary.len(),
496            "compaction probe: summary too short — skipping probe"
497        );
498        return Ok(None);
499    }
500
501    let questions = generate_probe_questions(&provider, &messages, config.max_questions).await?;
502
503    if questions.len() < 2 {
504        tracing::debug!(
505            count = questions.len(),
506            "compaction probe: fewer than 2 questions generated — skipping probe"
507        );
508        return Ok(None);
509    }
510
511    // Warn if any category is missing when we expected full coverage.
512    if config.max_questions >= 4 {
513        use std::collections::HashSet;
514        let covered: HashSet<_> = questions.iter().map(|q| q.category).collect();
515        for cat in [
516            ProbeCategory::Recall,
517            ProbeCategory::Artifact,
518            ProbeCategory::Continuation,
519            ProbeCategory::Decision,
520        ] {
521            if !covered.contains(&cat) {
522                tracing::warn!(
523                    category = ?cat,
524                    "compaction probe: LLM did not generate questions for category"
525                );
526            }
527        }
528    }
529
530    let answers = answer_probe_questions(&provider, &summary, &questions).await?;
531
532    let (per_question_scores, _simple_avg) = score_answers(&questions, &answers);
533
534    let (category_scores, score) = compute_category_scores(
535        &questions,
536        &per_question_scores,
537        config.category_weights.as_ref(),
538    );
539
540    let verdict = if score >= config.threshold {
541        ProbeVerdict::Pass
542    } else if score >= config.hard_fail_threshold {
543        ProbeVerdict::SoftFail
544    } else {
545        ProbeVerdict::HardFail
546    };
547
548    let model = provider.name().to_owned();
549
550    Ok(Some(CompactionProbeResult {
551        score,
552        category_scores,
553        questions,
554        answers,
555        per_question_scores,
556        verdict,
557        threshold: config.threshold,
558        hard_fail_threshold: config.hard_fail_threshold,
559        model,
560        duration_ms: 0, // filled in by validate_compaction
561    }))
562}
563
564#[cfg(test)]
565mod tests {
566    use super::*;
567
568    // --- score_answers tests ---
569
570    #[test]
571    fn score_perfect_match() {
572        let q = vec![ProbeQuestion {
573            question: "What crate is used?".into(),
574            expected_answer: "thiserror".into(),
575            category: ProbeCategory::Recall,
576        }];
577        let a = vec!["thiserror".into()];
578        let (scores, avg) = score_answers(&q, &a);
579        assert_eq!(scores.len(), 1);
580        assert!((avg - 1.0).abs() < 0.01, "expected ~1.0, got {avg}");
581    }
582
583    #[test]
584    fn score_complete_mismatch() {
585        let q = vec![ProbeQuestion {
586            question: "What file was modified?".into(),
587            expected_answer: "src/auth.rs".into(),
588            ..Default::default()
589        }];
590        let a = vec!["definitely not in the summary".into()];
591        let (scores, avg) = score_answers(&q, &a);
592        assert_eq!(scores.len(), 1);
593        // Very low overlap expected.
594        assert!(avg < 0.5, "expected low score, got {avg}");
595    }
596
597    #[test]
598    fn score_refusal_is_zero() {
599        let q = vec![ProbeQuestion {
600            question: "What was the decision?".into(),
601            expected_answer: "Use thiserror for typed errors".into(),
602            ..Default::default()
603        }];
604        for refusal in &[
605            "UNKNOWN",
606            "not mentioned",
607            "N/A",
608            "cannot determine",
609            "No information",
610        ] {
611            let a = vec![(*refusal).to_owned()];
612            let (_, avg) = score_answers(&q, &a);
613            assert!(avg < 0.01, "expected 0 for refusal '{refusal}', got {avg}");
614        }
615    }
616
617    #[test]
618    fn score_paraphrased_answer_above_half() {
619        // "thiserror was chosen for error types" vs "Use thiserror for typed errors"
620        // Shared tokens: "thiserror", "error" (and maybe "for"/"types"/"typed" with >=3 chars)
621        let q = vec![ProbeQuestion {
622            question: "What error handling crate was chosen?".into(),
623            expected_answer: "Use thiserror for typed errors in library crates".into(),
624            ..Default::default()
625        }];
626        let a = vec!["thiserror was chosen for error types in library crates".into()];
627        let (_, avg) = score_answers(&q, &a);
628        assert!(avg > 0.5, "expected >0.5 for paraphrase, got {avg}");
629    }
630
631    #[test]
632    fn score_empty_strings() {
633        let q = vec![ProbeQuestion {
634            question: "What?".into(),
635            expected_answer: String::new(),
636            ..Default::default()
637        }];
638        let a = vec![String::new()];
639        let (scores, avg) = score_answers(&q, &a);
640        assert_eq!(scores.len(), 1);
641        // Both empty — jaccard of two empty sets returns 1.0 (exact match).
642        assert!(
643            (avg - 1.0).abs() < 0.01,
644            "expected 1.0 for empty vs empty, got {avg}"
645        );
646    }
647
648    #[test]
649    fn score_empty_questions_list() {
650        let (scores, avg) = score_answers(&[], &[]);
651        assert!(scores.is_empty());
652        assert!((avg - 0.0).abs() < 0.01);
653    }
654
655    #[test]
656    fn score_file_path_exact() {
657        let q = vec![ProbeQuestion {
658            question: "Which file was modified?".into(),
659            expected_answer: "crates/zeph-memory/src/compaction_probe.rs".into(),
660            ..Default::default()
661        }];
662        let a = vec!["The file crates/zeph-memory/src/compaction_probe.rs was modified.".into()];
663        let (_, avg) = score_answers(&q, &a);
664        // Substring boost should fire: all expected tokens present in actual.
665        assert!(
666            avg > 0.8,
667            "expected high score for file path match, got {avg}"
668        );
669    }
670
671    #[test]
672    fn score_unicode_input() {
673        let q = vec![ProbeQuestion {
674            question: "Что было изменено?".into(),
675            expected_answer: "файл config.toml".into(),
676            ..Default::default()
677        }];
678        let a = vec!["config.toml был изменён".into()];
679        // Just verify no panic; score may vary.
680        let (scores, _) = score_answers(&q, &a);
681        assert_eq!(scores.len(), 1);
682    }
683
684    // --- verdict threshold tests ---
685
686    #[test]
687    fn verdict_thresholds() {
688        let config = CompactionProbeConfig::default();
689
690        // Pass >= 0.6
691        let score = 0.7_f32;
692        let verdict = if score >= config.threshold {
693            ProbeVerdict::Pass
694        } else if score >= config.hard_fail_threshold {
695            ProbeVerdict::SoftFail
696        } else {
697            ProbeVerdict::HardFail
698        };
699        assert_eq!(verdict, ProbeVerdict::Pass);
700
701        // SoftFail [0.35, 0.6)
702        let score = 0.5_f32;
703        let verdict = if score >= config.threshold {
704            ProbeVerdict::Pass
705        } else if score >= config.hard_fail_threshold {
706            ProbeVerdict::SoftFail
707        } else {
708            ProbeVerdict::HardFail
709        };
710        assert_eq!(verdict, ProbeVerdict::SoftFail);
711
712        // HardFail < 0.35
713        let score = 0.2_f32;
714        let verdict = if score >= config.threshold {
715            ProbeVerdict::Pass
716        } else if score >= config.hard_fail_threshold {
717            ProbeVerdict::SoftFail
718        } else {
719            ProbeVerdict::HardFail
720        };
721        assert_eq!(verdict, ProbeVerdict::HardFail);
722    }
723
724    // --- config defaults ---
725
726    #[test]
727    fn config_defaults() {
728        let c = CompactionProbeConfig::default();
729        assert!(!c.enabled);
730        assert!(c.probe_provider.is_none());
731        assert!((c.threshold - 0.6).abs() < 0.001);
732        assert!((c.hard_fail_threshold - 0.35).abs() < 0.001);
733        assert_eq!(c.max_questions, 5);
734        assert_eq!(c.timeout_secs, 15);
735        assert!(c.category_weights.is_none());
736    }
737
738    // --- serde round-trip ---
739
740    #[test]
741    fn config_serde_round_trip() {
742        let original = CompactionProbeConfig {
743            enabled: true,
744            probe_provider: Some(zeph_config::ProviderName::new("fast")),
745            threshold: 0.65,
746            hard_fail_threshold: 0.4,
747            max_questions: 5,
748            timeout_secs: 20,
749            category_weights: None,
750        };
751        let json = serde_json::to_string(&original).expect("serialize");
752        let restored: CompactionProbeConfig = serde_json::from_str(&json).expect("deserialize");
753        assert!(restored.enabled);
754        assert_eq!(restored.probe_provider.as_deref(), Some("fast"));
755        assert!((restored.threshold - 0.65).abs() < 0.001);
756    }
757
758    #[test]
759    fn probe_result_serde_round_trip() {
760        let result = CompactionProbeResult {
761            score: 0.75,
762            category_scores: vec![CategoryScore {
763                category: ProbeCategory::Recall,
764                score: 0.75,
765                probes_run: 1,
766            }],
767            questions: vec![ProbeQuestion {
768                question: "What?".into(),
769                expected_answer: "thiserror".into(),
770                category: ProbeCategory::Recall,
771            }],
772            answers: vec!["thiserror".into()],
773            per_question_scores: vec![1.0],
774            verdict: ProbeVerdict::Pass,
775            threshold: 0.6,
776            hard_fail_threshold: 0.35,
777            model: "haiku".into(),
778            duration_ms: 1234,
779        };
780        let json = serde_json::to_string(&result).expect("serialize");
781        let restored: CompactionProbeResult = serde_json::from_str(&json).expect("deserialize");
782        assert!((restored.score - 0.75).abs() < 0.001);
783        assert_eq!(restored.verdict, ProbeVerdict::Pass);
784        assert_eq!(restored.category_scores.len(), 1);
785    }
786
787    #[test]
788    fn probe_result_backward_compat_no_category_scores() {
789        // Old JSON without category_scores field must deserialize with empty vec.
790        let json = r#"{"score":0.75,"questions":[],"answers":[],"per_question_scores":[],"verdict":"Pass","threshold":0.6,"hard_fail_threshold":0.35,"model":"haiku","duration_ms":0}"#;
791        let restored: CompactionProbeResult = serde_json::from_str(json).expect("deserialize");
792        assert!(restored.category_scores.is_empty());
793    }
794
795    // --- fewer answers than questions (LLM returned truncated list) ---
796
797    #[test]
798    fn score_fewer_answers_than_questions() {
799        let questions = vec![
800            ProbeQuestion {
801                question: "What crate?".into(),
802                expected_answer: "thiserror".into(),
803                ..Default::default()
804            },
805            ProbeQuestion {
806                question: "What file?".into(),
807                expected_answer: "src/lib.rs".into(),
808                ..Default::default()
809            },
810            ProbeQuestion {
811                question: "What decision?".into(),
812                expected_answer: "use async traits".into(),
813                ..Default::default()
814            },
815        ];
816        // LLM only returned 1 answer for 3 questions.
817        let answers = vec!["thiserror".into()];
818        let (scores, avg) = score_answers(&questions, &answers);
819        // scores must have the same length as questions (missing answers → empty string → ~0).
820        assert_eq!(scores.len(), 3);
821        // First answer is a perfect match.
822        assert!(
823            (scores[0] - 1.0).abs() < 0.01,
824            "first score should be ~1.0, got {}",
825            scores[0]
826        );
827        // Missing answers score 0 (empty string vs non-empty expected).
828        assert!(
829            scores[1] < 0.5,
830            "second score should be low for missing answer, got {}",
831            scores[1]
832        );
833        assert!(
834            scores[2] < 0.5,
835            "third score should be low for missing answer, got {}",
836            scores[2]
837        );
838        // Average is dragged down by the two missing answers.
839        assert!(
840            avg < 0.5,
841            "average should be below 0.5 with 2 missing answers, got {avg}"
842        );
843    }
844
845    // --- exact boundary values for threshold ---
846
847    #[test]
848    fn verdict_boundary_at_threshold() {
849        let config = CompactionProbeConfig::default();
850
851        // Exactly at pass threshold → Pass.
852        let score = config.threshold;
853        let verdict = if score >= config.threshold {
854            ProbeVerdict::Pass
855        } else if score >= config.hard_fail_threshold {
856            ProbeVerdict::SoftFail
857        } else {
858            ProbeVerdict::HardFail
859        };
860        assert_eq!(verdict, ProbeVerdict::Pass);
861
862        // One ULP below pass threshold, above hard-fail → SoftFail.
863        let score = config.threshold - f32::EPSILON;
864        let verdict = if score >= config.threshold {
865            ProbeVerdict::Pass
866        } else if score >= config.hard_fail_threshold {
867            ProbeVerdict::SoftFail
868        } else {
869            ProbeVerdict::HardFail
870        };
871        assert_eq!(verdict, ProbeVerdict::SoftFail);
872
873        // Exactly at hard-fail threshold → SoftFail (boundary is inclusive).
874        let score = config.hard_fail_threshold;
875        let verdict = if score >= config.threshold {
876            ProbeVerdict::Pass
877        } else if score >= config.hard_fail_threshold {
878            ProbeVerdict::SoftFail
879        } else {
880            ProbeVerdict::HardFail
881        };
882        assert_eq!(verdict, ProbeVerdict::SoftFail);
883
884        // One ULP below hard-fail threshold → HardFail.
885        let score = config.hard_fail_threshold - f32::EPSILON;
886        let verdict = if score >= config.threshold {
887            ProbeVerdict::Pass
888        } else if score >= config.hard_fail_threshold {
889            ProbeVerdict::SoftFail
890        } else {
891            ProbeVerdict::HardFail
892        };
893        assert_eq!(verdict, ProbeVerdict::HardFail);
894    }
895
896    // --- config partial deserialization (serde default fields) ---
897
898    #[test]
899    fn config_partial_json_uses_defaults() {
900        // Only `enabled` is specified; all other fields must fall back to defaults via #[serde(default)].
901        let json = r#"{"enabled": true}"#;
902        let c: CompactionProbeConfig =
903            serde_json::from_str(json).expect("deserialize partial json");
904        assert!(c.enabled);
905        assert!(c.probe_provider.is_none());
906        assert!((c.threshold - 0.6).abs() < 0.001);
907        assert!((c.hard_fail_threshold - 0.35).abs() < 0.001);
908        assert_eq!(c.max_questions, 5);
909        assert_eq!(c.timeout_secs, 15);
910    }
911
912    #[test]
913    fn config_empty_json_uses_all_defaults() {
914        let c: CompactionProbeConfig = serde_json::from_str("{}").expect("deserialize empty json");
915        assert!(!c.enabled);
916        assert!(c.probe_provider.is_none());
917    }
918
919    // --- ProbeCategory serde ---
920
921    #[test]
922    fn probe_category_serde_lowercase() {
923        assert_eq!(
924            serde_json::to_string(&ProbeCategory::Recall).unwrap(),
925            r#""recall""#
926        );
927        assert_eq!(
928            serde_json::to_string(&ProbeCategory::Artifact).unwrap(),
929            r#""artifact""#
930        );
931        assert_eq!(
932            serde_json::to_string(&ProbeCategory::Continuation).unwrap(),
933            r#""continuation""#
934        );
935        assert_eq!(
936            serde_json::to_string(&ProbeCategory::Decision).unwrap(),
937            r#""decision""#
938        );
939        let cat: ProbeCategory = serde_json::from_str(r#""recall""#).unwrap();
940        assert_eq!(cat, ProbeCategory::Recall);
941    }
942
943    // --- category_weights TOML compat ---
944
945    #[test]
946    fn category_weights_toml_round_trip() {
947        let toml_str = r#"
948enabled = true
949probe_provider = "fast"
950threshold = 0.6
951hard_fail_threshold = 0.35
952max_questions = 5
953timeout_secs = 15
954[category_weights]
955recall = 1.5
956artifact = 1.0
957continuation = 1.0
958decision = 0.8
959"#;
960        let c: CompactionProbeConfig = toml::from_str(toml_str).expect("deserialize toml");
961        let weights = c.category_weights.as_ref().unwrap();
962        assert!((weights[&ProbeCategory::Recall] - 1.5).abs() < 0.001);
963        assert!((weights[&ProbeCategory::Decision] - 0.8).abs() < 0.001);
964    }
965
966    // --- compute_category_scores ---
967
968    #[test]
969    fn category_scores_equal_weights() {
970        let questions = vec![
971            ProbeQuestion {
972                question: "Q1".into(),
973                expected_answer: "A1".into(),
974                category: ProbeCategory::Recall,
975            },
976            ProbeQuestion {
977                question: "Q2".into(),
978                expected_answer: "A2".into(),
979                category: ProbeCategory::Artifact,
980            },
981        ];
982        let scores = [1.0_f32, 0.0_f32];
983        let (cats, overall) = compute_category_scores(&questions, &scores, None);
984        assert_eq!(cats.len(), 2);
985        // Equal weight: (1.0 + 0.0) / 2 = 0.5
986        assert!((overall - 0.5).abs() < 0.001, "overall={overall}");
987    }
988
989    #[test]
990    fn category_scores_missing_category_excluded() {
991        // Only Recall and Decision present; Artifact/Continuation absent.
992        let questions = vec![
993            ProbeQuestion {
994                question: "Q1".into(),
995                expected_answer: "A1".into(),
996                category: ProbeCategory::Recall,
997            },
998            ProbeQuestion {
999                question: "Q2".into(),
1000                expected_answer: "A2".into(),
1001                category: ProbeCategory::Decision,
1002            },
1003        ];
1004        let scores = [1.0_f32, 0.6_f32];
1005        let (cats, _overall) = compute_category_scores(&questions, &scores, None);
1006        assert_eq!(cats.len(), 2, "only categories with questions present");
1007        let categories: Vec<_> = cats.iter().map(|c| c.category).collect();
1008        assert!(!categories.contains(&ProbeCategory::Artifact));
1009        assert!(!categories.contains(&ProbeCategory::Continuation));
1010    }
1011
1012    #[test]
1013    fn category_scores_custom_weights() {
1014        let questions = vec![
1015            ProbeQuestion {
1016                question: "Q1".into(),
1017                expected_answer: "A1".into(),
1018                category: ProbeCategory::Recall,
1019            },
1020            ProbeQuestion {
1021                question: "Q2".into(),
1022                expected_answer: "A2".into(),
1023                category: ProbeCategory::Decision,
1024            },
1025        ];
1026        let scores = [1.0_f32, 0.0_f32];
1027        let mut weights = HashMap::new();
1028        weights.insert(ProbeCategory::Recall, 2.0_f32);
1029        weights.insert(ProbeCategory::Decision, 1.0_f32);
1030        let (_, overall) = compute_category_scores(&questions, &scores, Some(&weights));
1031        // (1.0*2 + 0.0*1) / (2+1) = 0.666..
1032        assert!(
1033            (overall - 2.0 / 3.0).abs() < 0.001,
1034            "expected ~0.667, got {overall}"
1035        );
1036    }
1037
1038    #[test]
1039    fn category_scores_all_zero_weights_fallback() {
1040        let questions = vec![
1041            ProbeQuestion {
1042                question: "Q1".into(),
1043                expected_answer: "A1".into(),
1044                category: ProbeCategory::Recall,
1045            },
1046            ProbeQuestion {
1047                question: "Q2".into(),
1048                expected_answer: "A2".into(),
1049                category: ProbeCategory::Artifact,
1050            },
1051        ];
1052        let scores = [1.0_f32, 0.0_f32];
1053        let mut weights = HashMap::new();
1054        weights.insert(ProbeCategory::Recall, 0.0_f32);
1055        weights.insert(ProbeCategory::Artifact, 0.0_f32);
1056        let (_, overall) = compute_category_scores(&questions, &scores, Some(&weights));
1057        // Fallback to equal weighting: (1.0 + 0.0) / 2 = 0.5
1058        assert!((overall - 0.5).abs() < 0.001, "overall={overall}");
1059    }
1060
1061    #[test]
1062    fn category_scores_empty_questions() {
1063        let (cats, overall) = compute_category_scores(&[], &[], None);
1064        assert!(cats.is_empty());
1065        assert!((overall - 0.0).abs() < 0.001);
1066    }
1067
1068    #[test]
1069    fn category_scores_multi_probe_single_category_averages() {
1070        // Three Recall questions with scores 1.0, 0.0, 0.5 → average 0.5.
1071        let questions = vec![
1072            ProbeQuestion {
1073                question: "Q1".into(),
1074                expected_answer: "A1".into(),
1075                category: ProbeCategory::Recall,
1076            },
1077            ProbeQuestion {
1078                question: "Q2".into(),
1079                expected_answer: "A2".into(),
1080                category: ProbeCategory::Recall,
1081            },
1082            ProbeQuestion {
1083                question: "Q3".into(),
1084                expected_answer: "A3".into(),
1085                category: ProbeCategory::Recall,
1086            },
1087        ];
1088        let scores = [1.0_f32, 0.0_f32, 0.5_f32];
1089        let (cats, overall) = compute_category_scores(&questions, &scores, None);
1090        assert_eq!(cats.len(), 1, "only one category present");
1091        assert_eq!(cats[0].category, ProbeCategory::Recall);
1092        assert_eq!(cats[0].probes_run, 3);
1093        assert!(
1094            (cats[0].score - 0.5).abs() < 0.001,
1095            "cat score={}",
1096            cats[0].score
1097        );
1098        // With one category, overall equals that category's average.
1099        assert!((overall - 0.5).abs() < 0.001, "overall={overall}");
1100    }
1101
1102    #[test]
1103    fn probe_question_serde_default_category() {
1104        // Old JSON without `category` field must deserialize to Recall via #[serde(default)].
1105        let json = r#"{"question":"What file?","expected_answer":"src/lib.rs"}"#;
1106        let q: ProbeQuestion = serde_json::from_str(json).expect("deserialize");
1107        assert_eq!(q.category, ProbeCategory::Recall);
1108        assert_eq!(q.question, "What file?");
1109        assert_eq!(q.expected_answer, "src/lib.rs");
1110    }
1111
1112    #[test]
1113    fn probe_question_serde_all_categories_round_trip() {
1114        for cat in [
1115            ProbeCategory::Recall,
1116            ProbeCategory::Artifact,
1117            ProbeCategory::Continuation,
1118            ProbeCategory::Decision,
1119        ] {
1120            let q = ProbeQuestion {
1121                question: "test?".into(),
1122                expected_answer: "answer".into(),
1123                category: cat,
1124            };
1125            let json = serde_json::to_string(&q).expect("serialize");
1126            let restored: ProbeQuestion = serde_json::from_str(&json).expect("deserialize");
1127            assert_eq!(restored.category, cat);
1128        }
1129    }
1130}