Skip to main content

zeph_memory/
compaction_probe.rs

1// SPDX-FileCopyrightText: 2026 Andrei G <bug-ops>
2// SPDX-License-Identifier: MIT OR Apache-2.0
3
4//! Compaction probe: validates summary quality before committing it to the context.
5//!
6//! Generates factual questions from the messages being compacted, then answers them
7//! using only the summary text, and scores the answers against expected values.
8//! Returns a [`CompactionProbeResult`] that the caller uses to decide whether to
9//! commit or reject the summary.
10
11use std::collections::HashMap;
12use std::time::Instant;
13
14use schemars::JsonSchema;
15use serde::{Deserialize, Serialize};
16use zeph_llm::any::AnyProvider;
17use zeph_llm::provider::{LlmProvider as _, Message, MessageMetadata, MessagePart, Role};
18
19use crate::error::MemoryError;
20
21// --- Data structures ---
22
23use zeph_config::memory::{CompactionProbeConfig, ProbeCategory};
24
/// Per-category scoring breakdown from a compaction probe run.
///
/// One entry is produced for each category that had at least one scored question.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CategoryScore {
    /// Category this entry aggregates.
    pub category: ProbeCategory,
    /// Average score of all questions in this category, in [0.0, 1.0].
    pub score: f32,
    /// Number of questions generated for this category.
    pub probes_run: u32,
}
34
/// A single factual question with the expected answer.
///
/// Also used as the element type of the structured LLM output, hence the
/// `JsonSchema` derive.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct ProbeQuestion {
    /// Factual question about the compacted messages.
    pub question: String,
    /// Expected correct answer extractable from the original messages.
    pub expected_answer: String,
    /// Functional category of this question.
    ///
    /// Falls back to `ProbeCategory::Recall` when the field is absent during
    /// deserialization.
    #[serde(default = "default_probe_category")]
    pub category: ProbeCategory,
}
46
/// Serde fallback for `ProbeQuestion::category`: a question deserialized without a
/// `category` field is treated as a recall question.
fn default_probe_category() -> ProbeCategory {
    ProbeCategory::Recall
}
50
51impl Default for ProbeQuestion {
52    fn default() -> Self {
53        Self {
54            question: String::new(),
55            expected_answer: String::new(),
56            category: ProbeCategory::Recall,
57        }
58    }
59}
60
/// Verdict for compaction probe quality.
///
/// Three score-based tiers (`Pass`, `SoftFail`, `HardFail`) plus `Error` for runs
/// that produced no score at all.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[non_exhaustive]
pub enum ProbeVerdict {
    /// Score >= `threshold`: summary preserves enough context. Proceed.
    Pass,
    /// Score in [`hard_fail_threshold`, `threshold`): summary is borderline.
    /// Proceed with compaction but log a warning.
    SoftFail,
    /// Score < `hard_fail_threshold`: summary lost critical facts. Block compaction.
    HardFail,
    /// Transport/timeout failure — no quality score produced.
    // NOTE(review): nothing in this module constructs `Error`; presumably callers
    // set it when the probe itself fails — confirm against call sites.
    Error,
}
75
/// Full result of a compaction probe run.
///
/// `questions`, `answers` and `per_question_scores` are meant to be read together
/// by index.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CompactionProbeResult {
    /// Overall score in [0.0, 1.0].
    pub score: f32,
    /// Per-category breakdown. Categories with 0 questions are omitted.
    ///
    /// Defaults to an empty list when the field is missing from serialized data.
    #[serde(default)]
    pub category_scores: Vec<CategoryScore>,
    /// Per-question breakdown.
    pub questions: Vec<ProbeQuestion>,
    /// LLM answers to the questions (positionally aligned with `questions`).
    pub answers: Vec<String>,
    /// Per-question similarity scores.
    pub per_question_scores: Vec<f32>,
    /// Verdict derived from `score`, `threshold` and `hard_fail_threshold`.
    pub verdict: ProbeVerdict,
    /// Pass threshold used for this run.
    pub threshold: f32,
    /// Hard-fail threshold used for this run.
    pub hard_fail_threshold: f32,
    /// Model name used for the probe.
    pub model: String,
    /// Wall-clock duration in milliseconds.
    pub duration_ms: u64,
}
100
101// --- Structured LLM output types ---
102
/// Structured-output envelope for the question-generation LLM call:
/// `{"questions": [...]}`.
#[derive(Debug, Deserialize, JsonSchema)]
struct ProbeQuestionsOutput {
    /// Generated questions, in the order the LLM returned them.
    questions: Vec<ProbeQuestion>,
}
107
108// --- Category scoring ---
109
110/// Group per-question scores by category and compute per-category averages.
111///
112/// Categories with no questions are excluded from the returned list.
113#[must_use]
114fn compute_category_scores(
115    questions: &[ProbeQuestion],
116    per_question_scores: &[f32],
117    category_weights: Option<&HashMap<ProbeCategory, f32>>,
118) -> (Vec<CategoryScore>, f32) {
119    // Group scores by category.
120    let mut by_cat: HashMap<ProbeCategory, Vec<f32>> = HashMap::new();
121    for (q, &s) in questions.iter().zip(per_question_scores.iter()) {
122        by_cat.entry(q.category).or_default().push(s);
123    }
124
125    #[allow(clippy::cast_precision_loss)]
126    let category_scores: Vec<CategoryScore> = by_cat
127        .into_iter()
128        .map(|(category, scores)| {
129            let avg = scores.iter().sum::<f32>() / scores.len() as f32;
130            CategoryScore {
131                category,
132                score: avg,
133                #[allow(clippy::cast_possible_truncation)]
134                probes_run: scores.len() as u32,
135            }
136        })
137        .collect();
138
139    if category_scores.is_empty() {
140        return (category_scores, 0.0);
141    }
142
143    // Compute weighted overall score. Default: equal weights.
144    let mut weighted_sum = 0.0_f32;
145    let mut weight_total = 0.0_f32;
146    for cs in &category_scores {
147        let raw_w = category_weights
148            .and_then(|m| m.get(&cs.category).copied())
149            .unwrap_or(1.0);
150        if raw_w < 0.0 {
151            tracing::warn!(
152                category = ?cs.category,
153                weight = raw_w,
154                "category_weights contains a negative value — treating as 0.0 (category excluded from scoring)"
155            );
156        }
157        let w = raw_w.max(0.0);
158        weighted_sum += cs.score * w;
159        weight_total += w;
160    }
161
162    let overall = if weight_total > 0.0 {
163        weighted_sum / weight_total
164    } else {
165        // All weights are 0 — fall back to equal weighting.
166        #[allow(clippy::cast_precision_loss)]
167        let n = category_scores.len() as f32;
168        category_scores.iter().map(|cs| cs.score).sum::<f32>() / n
169    };
170
171    (category_scores, overall)
172}
173
/// Structured-output envelope for the answering LLM call: `{"answers": [...]}`.
#[derive(Debug, Deserialize, JsonSchema)]
struct ProbeAnswersOutput {
    /// Answers as returned by the LLM; the prompt asks for one per question, in order.
    answers: Vec<String>,
}
178
179// --- Scoring ---
180
/// Refusal indicators: if the actual answer contains any of these as a standalone
/// word or phrase, score it 0.0. All entries are lowercase.
const REFUSAL_PATTERNS: &[&str] = &[
    "unknown",
    "not mentioned",
    "not found",
    "n/a",
    "cannot determine",
    "no information",
    "not provided",
    "not specified",
    "not stated",
    "not available",
];

/// Returns `true` when `text` contains a refusal phrase from [`REFUSAL_PATTERNS`].
///
/// Matching is case-insensitive and anchored on word boundaries, so a phrase
/// embedded in a longer identifier or path does not count: `"src/main/app.rs"`
/// contains the bytes `n/a` and `"UnknownError"` contains `unknown`, but neither
/// is a refusal.
fn is_refusal(text: &str) -> bool {
    let lower = text.to_lowercase();
    REFUSAL_PATTERNS
        .iter()
        .any(|pattern| contains_phrase(&lower, pattern))
}

/// Returns `true` if `needle` occurs in `haystack` with no word character
/// (alphanumeric or `_`) directly before or after the occurrence.
fn contains_phrase(haystack: &str, needle: &str) -> bool {
    let is_word = |c: char| c.is_alphanumeric() || c == '_';
    haystack.match_indices(needle).any(|(start, hit)| {
        let end = start + hit.len();
        let glued_before = haystack[..start].chars().next_back().is_some_and(is_word);
        let glued_after = haystack[end..].chars().next().is_some_and(is_word);
        !glued_before && !glued_after
    })
}
199
/// Normalize a string: lowercase, split on non-alphanumeric chars, keep tokens >= 3 chars.
///
/// Token length is counted in characters, not bytes, so short non-ASCII words
/// (e.g. two Cyrillic letters, which occupy four bytes) are filtered out exactly
/// like short ASCII words.
fn normalize_tokens(text: &str) -> Vec<String> {
    text.to_lowercase()
        .split(|c: char| !c.is_alphanumeric())
        .filter(|t| t.chars().count() >= 3)
        .map(String::from)
        .collect()
}
208
/// Jaccard similarity of two token lists, compared as sets of unique tokens.
///
/// Two empty inputs are considered identical and score 1.0.
fn jaccard(a: &[String], b: &[String]) -> f32 {
    use std::collections::HashSet;

    let left: HashSet<&str> = a.iter().map(String::as_str).collect();
    let right: HashSet<&str> = b.iter().map(String::as_str).collect();

    // An empty union means both inputs were empty.
    let combined = left.union(&right).count();
    if combined == 0 {
        return 1.0;
    }

    let shared = left.intersection(&right).count();
    #[allow(clippy::cast_precision_loss)]
    {
        shared as f32 / combined as f32
    }
}
225
226/// Score a single (expected, actual) answer pair using token-set-ratio.
227fn score_pair(expected: &str, actual: &str) -> f32 {
228    if is_refusal(actual) {
229        return 0.0;
230    }
231
232    let tokens_e = normalize_tokens(expected);
233    let tokens_a = normalize_tokens(actual);
234
235    // Substring boost: if all expected tokens appear in actual, it's an exact match.
236    if !tokens_e.is_empty() {
237        let set_e: std::collections::HashSet<&str> = tokens_e.iter().map(String::as_str).collect();
238        let set_a: std::collections::HashSet<&str> = tokens_a.iter().map(String::as_str).collect();
239        if set_e.is_subset(&set_a) {
240            return 1.0;
241        }
242    }
243
244    // Token-set-ratio: max of three Jaccard variants.
245    let j_full = jaccard(&tokens_e, &tokens_a);
246
247    // Intersection with each set individually (handles subset relationships).
248    let set_e: std::collections::HashSet<&str> = tokens_e.iter().map(String::as_str).collect();
249    let set_a: std::collections::HashSet<&str> = tokens_a.iter().map(String::as_str).collect();
250    let intersection: Vec<String> = set_e
251        .intersection(&set_a)
252        .map(|s| (*s).to_owned())
253        .collect();
254
255    #[allow(clippy::cast_precision_loss)]
256    let j_e = if tokens_e.is_empty() {
257        0.0_f32
258    } else {
259        intersection.len() as f32 / tokens_e.len() as f32
260    };
261    #[allow(clippy::cast_precision_loss)]
262    let j_a = if tokens_a.is_empty() {
263        0.0_f32
264    } else {
265        intersection.len() as f32 / tokens_a.len() as f32
266    };
267
268    j_full.max(j_e).max(j_a)
269}
270
271/// Score answers against expected values using token-set-ratio similarity.
272///
273/// Returns `(per_question_scores, overall_average)`.
274#[must_use]
275pub fn score_answers(questions: &[ProbeQuestion], answers: &[String]) -> (Vec<f32>, f32) {
276    if questions.is_empty() {
277        return (vec![], 0.0);
278    }
279    let scores: Vec<f32> = questions
280        .iter()
281        .zip(answers.iter().chain(std::iter::repeat(&String::new())))
282        .map(|(q, a)| score_pair(&q.expected_answer, a))
283        .collect();
284    #[allow(clippy::cast_precision_loss)]
285    let avg = if scores.is_empty() {
286        0.0
287    } else {
288        scores.iter().sum::<f32>() / scores.len() as f32
289    };
290    (scores, avg)
291}
292
293// --- LLM calls ---
294
295/// Truncate tool-result bodies to 500 chars to avoid flooding the probe with raw output.
296fn truncate_tool_bodies(messages: &[Message]) -> Vec<Message> {
297    messages
298        .iter()
299        .map(|m| {
300            let mut msg = m.clone();
301            for part in &mut msg.parts {
302                if let MessagePart::ToolOutput { body, .. } = part {
303                    if body.len() <= 500 {
304                        continue;
305                    }
306                    body.truncate(500);
307                    body.push('\u{2026}');
308                }
309            }
310            msg.rebuild_content();
311            msg
312        })
313        .collect()
314}
315
316/// Generate factual probe questions from the messages being compacted.
317///
318/// Uses a single LLM call with structured output. Tool-result bodies are
319/// truncated to 500 chars to focus on decisions and outcomes rather than raw tool output.
320///
321/// # Errors
322///
323/// Returns `MemoryError::Llm` if the LLM call fails.
324#[cfg_attr(
325    feature = "profiling",
326    tracing::instrument(name = "memory.compaction_probe", skip_all)
327)]
328pub async fn generate_probe_questions(
329    provider: &AnyProvider,
330    messages: &[Message],
331    max_questions: usize,
332) -> Result<Vec<ProbeQuestion>, MemoryError> {
333    let truncated = truncate_tool_bodies(messages);
334
335    let mut history = String::new();
336    for msg in &truncated {
337        let role = match msg.role {
338            Role::Assistant => "assistant",
339            Role::System => "system",
340            Role::User | _ => "user",
341        };
342        history.push_str(role);
343        history.push_str(": ");
344        history.push_str(&msg.content);
345        history.push('\n');
346    }
347
348    let prompt = format!(
349        "Given the following conversation excerpt, generate {max_questions} factual questions \
350         that test whether a summary preserves the most important concrete details.\n\
351         \n\
352         You MUST generate at least one question per category when max_questions >= 4. \
353         If the conversation lacks information for a category, generate a question noting that absence.\n\
354         \n\
355         Categories:\n\
356         - recall: Specific facts that survived (file paths, function names, values). \
357           Example: \"What file was modified?\"\n\
358         - artifact: Which files/tools/URLs the agent used. \
359           Example: \"Which tool was executed?\"\n\
360         - continuation: Next steps, blockers, open questions. \
361           Example: \"What is the next step?\"\n\
362         - decision: Past reasoning traces (why X over Y, trade-offs). \
363           Example: \"Why was X chosen over Y?\"\n\
364         \n\
365         Do NOT generate questions about:\n\
366         - Raw tool output content (compiler warnings, test output line numbers)\n\
367         - Intermediate debugging steps that were superseded\n\
368         - Opinions or reasoning that cannot be verified\n\
369         \n\
370         Each question must have a single unambiguous expected answer extractable from the text.\n\
371         \n\
372         Conversation:\n{history}\n\
373         \n\
374         Respond in JSON with schema: {{\"questions\": [{{\"question\": \"...\", \
375         \"expected_answer\": \"...\", \"category\": \"recall|artifact|continuation|decision\"}}]}}"
376    );
377
378    let msgs = [Message {
379        role: Role::User,
380        content: prompt,
381        parts: vec![],
382        metadata: MessageMetadata::default(),
383    }];
384
385    let mut output: ProbeQuestionsOutput = provider
386        .chat_typed_erased::<ProbeQuestionsOutput>(&msgs)
387        .await
388        .map_err(MemoryError::Llm)?;
389
390    // Cap the list to max_questions: a misbehaving LLM could return more.
391    output.questions.truncate(max_questions);
392
393    Ok(output.questions)
394}
395
396/// Answer probe questions using only the compaction summary as context.
397///
398/// # Errors
399///
400/// Returns `MemoryError::Llm` if the LLM call fails.
401pub async fn answer_probe_questions(
402    provider: &AnyProvider,
403    summary: &str,
404    questions: &[ProbeQuestion],
405) -> Result<Vec<String>, MemoryError> {
406    let mut numbered = String::new();
407    for (i, q) in questions.iter().enumerate() {
408        use std::fmt::Write as _;
409        let _ = writeln!(numbered, "{}. {}", i + 1, q.question);
410    }
411
412    let prompt = format!(
413        "Given the following summary of a conversation, answer each question using ONLY \
414         information present in the summary. If the answer is not in the summary, respond \
415         with \"UNKNOWN\".\n\
416         \n\
417         Summary:\n{summary}\n\
418         \n\
419         Questions:\n{numbered}\n\
420         \n\
421         Respond in JSON with schema: {{\"answers\": [\"answer1\", \"answer2\", ...]}}"
422    );
423
424    let msgs = [Message {
425        role: Role::User,
426        content: prompt,
427        parts: vec![],
428        metadata: MessageMetadata::default(),
429    }];
430
431    let output: ProbeAnswersOutput = provider
432        .chat_typed_erased::<ProbeAnswersOutput>(&msgs)
433        .await
434        .map_err(MemoryError::Llm)?;
435
436    Ok(output.answers)
437}
438
439/// Run the compaction probe: generate questions, answer them from the summary, score results.
440///
441/// Returns `Ok(None)` when:
442/// - Probe is disabled (`config.enabled = false`)
443/// - The probe times out
444/// - Fewer than 2 questions are generated (insufficient statistical power)
445///
446/// The caller treats `None` as "no opinion" and proceeds with compaction.
447///
448/// # Errors
449///
450/// Returns `MemoryError` if an LLM call fails. Callers should treat this as non-fatal
451/// and proceed with compaction.
452pub async fn validate_compaction(
453    provider: AnyProvider,
454    messages: Vec<Message>,
455    summary: String,
456    config: &CompactionProbeConfig,
457) -> Result<Option<CompactionProbeResult>, MemoryError> {
458    if !config.enabled {
459        return Ok(None);
460    }
461
462    let timeout = std::time::Duration::from_secs(config.timeout_secs);
463    let start = Instant::now();
464
465    let result = tokio::time::timeout(timeout, async {
466        run_probe(provider, messages, summary, config).await
467    })
468    .await;
469
470    match result {
471        Ok(inner) => inner,
472        Err(_elapsed) => {
473            tracing::warn!(
474                timeout_secs = config.timeout_secs,
475                "compaction probe timed out — proceeding with compaction"
476            );
477            Ok(None)
478        }
479    }
480    .map(|opt| {
481        opt.map(|mut r| {
482            r.duration_ms = u64::try_from(start.elapsed().as_millis()).unwrap_or(u64::MAX);
483            r
484        })
485    })
486}
487
488async fn run_probe(
489    provider: AnyProvider,
490    messages: Vec<Message>,
491    summary: String,
492    config: &CompactionProbeConfig,
493) -> Result<Option<CompactionProbeResult>, MemoryError> {
494    if summary.len() < 10 {
495        tracing::warn!(
496            len = summary.len(),
497            "compaction probe: summary too short — skipping probe"
498        );
499        return Ok(None);
500    }
501
502    let questions = generate_probe_questions(&provider, &messages, config.max_questions).await?;
503
504    if questions.len() < 2 {
505        tracing::debug!(
506            count = questions.len(),
507            "compaction probe: fewer than 2 questions generated — skipping probe"
508        );
509        return Ok(None);
510    }
511
512    // Warn if any category is missing when we expected full coverage.
513    if config.max_questions >= 4 {
514        use std::collections::HashSet;
515        let covered: HashSet<_> = questions.iter().map(|q| q.category).collect();
516        for cat in [
517            ProbeCategory::Recall,
518            ProbeCategory::Artifact,
519            ProbeCategory::Continuation,
520            ProbeCategory::Decision,
521        ] {
522            if !covered.contains(&cat) {
523                tracing::warn!(
524                    category = ?cat,
525                    "compaction probe: LLM did not generate questions for category"
526                );
527            }
528        }
529    }
530
531    let answers = answer_probe_questions(&provider, &summary, &questions).await?;
532
533    let (per_question_scores, _simple_avg) = score_answers(&questions, &answers);
534
535    let (category_scores, score) = compute_category_scores(
536        &questions,
537        &per_question_scores,
538        config.category_weights.as_ref(),
539    );
540
541    let verdict = if score >= config.threshold {
542        ProbeVerdict::Pass
543    } else if score >= config.hard_fail_threshold {
544        ProbeVerdict::SoftFail
545    } else {
546        ProbeVerdict::HardFail
547    };
548
549    let model = provider.name().to_owned();
550
551    Ok(Some(CompactionProbeResult {
552        score,
553        category_scores,
554        questions,
555        answers,
556        per_question_scores,
557        verdict,
558        threshold: config.threshold,
559        hard_fail_threshold: config.hard_fail_threshold,
560        model,
561        duration_ms: 0, // filled in by validate_compaction
562    }))
563}
564
565#[cfg(test)]
566mod tests {
567    use super::*;
568
569    // --- score_answers tests ---
570
571    #[test]
572    fn score_perfect_match() {
573        let q = vec![ProbeQuestion {
574            question: "What crate is used?".into(),
575            expected_answer: "thiserror".into(),
576            category: ProbeCategory::Recall,
577        }];
578        let a = vec!["thiserror".into()];
579        let (scores, avg) = score_answers(&q, &a);
580        assert_eq!(scores.len(), 1);
581        assert!((avg - 1.0).abs() < 0.01, "expected ~1.0, got {avg}");
582    }
583
584    #[test]
585    fn score_complete_mismatch() {
586        let q = vec![ProbeQuestion {
587            question: "What file was modified?".into(),
588            expected_answer: "src/auth.rs".into(),
589            ..Default::default()
590        }];
591        let a = vec!["definitely not in the summary".into()];
592        let (scores, avg) = score_answers(&q, &a);
593        assert_eq!(scores.len(), 1);
594        // Very low overlap expected.
595        assert!(avg < 0.5, "expected low score, got {avg}");
596    }
597
598    #[test]
599    fn score_refusal_is_zero() {
600        let q = vec![ProbeQuestion {
601            question: "What was the decision?".into(),
602            expected_answer: "Use thiserror for typed errors".into(),
603            ..Default::default()
604        }];
605        for refusal in &[
606            "UNKNOWN",
607            "not mentioned",
608            "N/A",
609            "cannot determine",
610            "No information",
611        ] {
612            let a = vec![(*refusal).to_owned()];
613            let (_, avg) = score_answers(&q, &a);
614            assert!(avg < 0.01, "expected 0 for refusal '{refusal}', got {avg}");
615        }
616    }
617
618    #[test]
619    fn score_paraphrased_answer_above_half() {
620        // "thiserror was chosen for error types" vs "Use thiserror for typed errors"
621        // Shared tokens: "thiserror", "error" (and maybe "for"/"types"/"typed" with >=3 chars)
622        let q = vec![ProbeQuestion {
623            question: "What error handling crate was chosen?".into(),
624            expected_answer: "Use thiserror for typed errors in library crates".into(),
625            ..Default::default()
626        }];
627        let a = vec!["thiserror was chosen for error types in library crates".into()];
628        let (_, avg) = score_answers(&q, &a);
629        assert!(avg > 0.5, "expected >0.5 for paraphrase, got {avg}");
630    }
631
632    #[test]
633    fn score_empty_strings() {
634        let q = vec![ProbeQuestion {
635            question: "What?".into(),
636            expected_answer: String::new(),
637            ..Default::default()
638        }];
639        let a = vec![String::new()];
640        let (scores, avg) = score_answers(&q, &a);
641        assert_eq!(scores.len(), 1);
642        // Both empty — jaccard of two empty sets returns 1.0 (exact match).
643        assert!(
644            (avg - 1.0).abs() < 0.01,
645            "expected 1.0 for empty vs empty, got {avg}"
646        );
647    }
648
649    #[test]
650    fn score_empty_questions_list() {
651        let (scores, avg) = score_answers(&[], &[]);
652        assert!(scores.is_empty());
653        assert!((avg - 0.0).abs() < 0.01);
654    }
655
656    #[test]
657    fn score_file_path_exact() {
658        let q = vec![ProbeQuestion {
659            question: "Which file was modified?".into(),
660            expected_answer: "crates/zeph-memory/src/compaction_probe.rs".into(),
661            ..Default::default()
662        }];
663        let a = vec!["The file crates/zeph-memory/src/compaction_probe.rs was modified.".into()];
664        let (_, avg) = score_answers(&q, &a);
665        // Substring boost should fire: all expected tokens present in actual.
666        assert!(
667            avg > 0.8,
668            "expected high score for file path match, got {avg}"
669        );
670    }
671
672    #[test]
673    fn score_unicode_input() {
674        let q = vec![ProbeQuestion {
675            question: "Что было изменено?".into(),
676            expected_answer: "файл config.toml".into(),
677            ..Default::default()
678        }];
679        let a = vec!["config.toml был изменён".into()];
680        // Just verify no panic; score may vary.
681        let (scores, _) = score_answers(&q, &a);
682        assert_eq!(scores.len(), 1);
683    }
684
685    // --- verdict threshold tests ---
686
687    #[test]
688    fn verdict_thresholds() {
689        let config = CompactionProbeConfig::default();
690
691        // Pass >= 0.6
692        let score = 0.7_f32;
693        let verdict = if score >= config.threshold {
694            ProbeVerdict::Pass
695        } else if score >= config.hard_fail_threshold {
696            ProbeVerdict::SoftFail
697        } else {
698            ProbeVerdict::HardFail
699        };
700        assert_eq!(verdict, ProbeVerdict::Pass);
701
702        // SoftFail [0.35, 0.6)
703        let score = 0.5_f32;
704        let verdict = if score >= config.threshold {
705            ProbeVerdict::Pass
706        } else if score >= config.hard_fail_threshold {
707            ProbeVerdict::SoftFail
708        } else {
709            ProbeVerdict::HardFail
710        };
711        assert_eq!(verdict, ProbeVerdict::SoftFail);
712
713        // HardFail < 0.35
714        let score = 0.2_f32;
715        let verdict = if score >= config.threshold {
716            ProbeVerdict::Pass
717        } else if score >= config.hard_fail_threshold {
718            ProbeVerdict::SoftFail
719        } else {
720            ProbeVerdict::HardFail
721        };
722        assert_eq!(verdict, ProbeVerdict::HardFail);
723    }
724
725    // --- config defaults ---
726
727    #[test]
728    fn config_defaults() {
729        let c = CompactionProbeConfig::default();
730        assert!(!c.enabled);
731        assert!(c.probe_provider.is_none());
732        assert!((c.threshold - 0.6).abs() < 0.001);
733        assert!((c.hard_fail_threshold - 0.35).abs() < 0.001);
734        assert_eq!(c.max_questions, 5);
735        assert_eq!(c.timeout_secs, 15);
736        assert!(c.category_weights.is_none());
737    }
738
739    // --- serde round-trip ---
740
741    #[test]
742    fn config_serde_round_trip() {
743        let original = CompactionProbeConfig {
744            enabled: true,
745            probe_provider: Some(zeph_config::ProviderName::new("fast")),
746            threshold: 0.65,
747            hard_fail_threshold: 0.4,
748            max_questions: 5,
749            timeout_secs: 20,
750            category_weights: None,
751        };
752        let json = serde_json::to_string(&original).expect("serialize");
753        let restored: CompactionProbeConfig = serde_json::from_str(&json).expect("deserialize");
754        assert!(restored.enabled);
755        assert_eq!(
756            restored
757                .probe_provider
758                .as_ref()
759                .map(zeph_config::ProviderName::as_str),
760            Some("fast")
761        );
762        assert!((restored.threshold - 0.65).abs() < 0.001);
763    }
764
765    #[test]
766    fn probe_result_serde_round_trip() {
767        let result = CompactionProbeResult {
768            score: 0.75,
769            category_scores: vec![CategoryScore {
770                category: ProbeCategory::Recall,
771                score: 0.75,
772                probes_run: 1,
773            }],
774            questions: vec![ProbeQuestion {
775                question: "What?".into(),
776                expected_answer: "thiserror".into(),
777                category: ProbeCategory::Recall,
778            }],
779            answers: vec!["thiserror".into()],
780            per_question_scores: vec![1.0],
781            verdict: ProbeVerdict::Pass,
782            threshold: 0.6,
783            hard_fail_threshold: 0.35,
784            model: "haiku".into(),
785            duration_ms: 1234,
786        };
787        let json = serde_json::to_string(&result).expect("serialize");
788        let restored: CompactionProbeResult = serde_json::from_str(&json).expect("deserialize");
789        assert!((restored.score - 0.75).abs() < 0.001);
790        assert_eq!(restored.verdict, ProbeVerdict::Pass);
791        assert_eq!(restored.category_scores.len(), 1);
792    }
793
794    #[test]
795    fn probe_result_backward_compat_no_category_scores() {
796        // Old JSON without category_scores field must deserialize with empty vec.
797        let json = r#"{"score":0.75,"questions":[],"answers":[],"per_question_scores":[],"verdict":"Pass","threshold":0.6,"hard_fail_threshold":0.35,"model":"haiku","duration_ms":0}"#;
798        let restored: CompactionProbeResult = serde_json::from_str(json).expect("deserialize");
799        assert!(restored.category_scores.is_empty());
800    }
801
802    // --- fewer answers than questions (LLM returned truncated list) ---
803
804    #[test]
805    fn score_fewer_answers_than_questions() {
806        let questions = vec![
807            ProbeQuestion {
808                question: "What crate?".into(),
809                expected_answer: "thiserror".into(),
810                ..Default::default()
811            },
812            ProbeQuestion {
813                question: "What file?".into(),
814                expected_answer: "src/lib.rs".into(),
815                ..Default::default()
816            },
817            ProbeQuestion {
818                question: "What decision?".into(),
819                expected_answer: "use async traits".into(),
820                ..Default::default()
821            },
822        ];
823        // LLM only returned 1 answer for 3 questions.
824        let answers = vec!["thiserror".into()];
825        let (scores, avg) = score_answers(&questions, &answers);
826        // scores must have the same length as questions (missing answers → empty string → ~0).
827        assert_eq!(scores.len(), 3);
828        // First answer is a perfect match.
829        assert!(
830            (scores[0] - 1.0).abs() < 0.01,
831            "first score should be ~1.0, got {}",
832            scores[0]
833        );
834        // Missing answers score 0 (empty string vs non-empty expected).
835        assert!(
836            scores[1] < 0.5,
837            "second score should be low for missing answer, got {}",
838            scores[1]
839        );
840        assert!(
841            scores[2] < 0.5,
842            "third score should be low for missing answer, got {}",
843            scores[2]
844        );
845        // Average is dragged down by the two missing answers.
846        assert!(
847            avg < 0.5,
848            "average should be below 0.5 with 2 missing answers, got {avg}"
849        );
850    }
851
852    // --- exact boundary values for threshold ---
853
854    #[test]
855    fn verdict_boundary_at_threshold() {
856        let config = CompactionProbeConfig::default();
857
858        // Exactly at pass threshold → Pass.
859        let score = config.threshold;
860        let verdict = if score >= config.threshold {
861            ProbeVerdict::Pass
862        } else if score >= config.hard_fail_threshold {
863            ProbeVerdict::SoftFail
864        } else {
865            ProbeVerdict::HardFail
866        };
867        assert_eq!(verdict, ProbeVerdict::Pass);
868
869        // One ULP below pass threshold, above hard-fail → SoftFail.
870        let score = config.threshold - f32::EPSILON;
871        let verdict = if score >= config.threshold {
872            ProbeVerdict::Pass
873        } else if score >= config.hard_fail_threshold {
874            ProbeVerdict::SoftFail
875        } else {
876            ProbeVerdict::HardFail
877        };
878        assert_eq!(verdict, ProbeVerdict::SoftFail);
879
880        // Exactly at hard-fail threshold → SoftFail (boundary is inclusive).
881        let score = config.hard_fail_threshold;
882        let verdict = if score >= config.threshold {
883            ProbeVerdict::Pass
884        } else if score >= config.hard_fail_threshold {
885            ProbeVerdict::SoftFail
886        } else {
887            ProbeVerdict::HardFail
888        };
889        assert_eq!(verdict, ProbeVerdict::SoftFail);
890
891        // One ULP below hard-fail threshold → HardFail.
892        let score = config.hard_fail_threshold - f32::EPSILON;
893        let verdict = if score >= config.threshold {
894            ProbeVerdict::Pass
895        } else if score >= config.hard_fail_threshold {
896            ProbeVerdict::SoftFail
897        } else {
898            ProbeVerdict::HardFail
899        };
900        assert_eq!(verdict, ProbeVerdict::HardFail);
901    }
902
903    // --- config partial deserialization (serde default fields) ---
904
905    #[test]
906    fn config_partial_json_uses_defaults() {
907        // Only `enabled` is specified; all other fields must fall back to defaults via #[serde(default)].
908        let json = r#"{"enabled": true}"#;
909        let c: CompactionProbeConfig =
910            serde_json::from_str(json).expect("deserialize partial json");
911        assert!(c.enabled);
912        assert!(c.probe_provider.is_none());
913        assert!((c.threshold - 0.6).abs() < 0.001);
914        assert!((c.hard_fail_threshold - 0.35).abs() < 0.001);
915        assert_eq!(c.max_questions, 5);
916        assert_eq!(c.timeout_secs, 15);
917    }
918
919    #[test]
920    fn config_empty_json_uses_all_defaults() {
921        let c: CompactionProbeConfig = serde_json::from_str("{}").expect("deserialize empty json");
922        assert!(!c.enabled);
923        assert!(c.probe_provider.is_none());
924    }
925
926    // --- ProbeCategory serde ---
927
928    #[test]
929    fn probe_category_serde_lowercase() {
930        assert_eq!(
931            serde_json::to_string(&ProbeCategory::Recall).unwrap(),
932            r#""recall""#
933        );
934        assert_eq!(
935            serde_json::to_string(&ProbeCategory::Artifact).unwrap(),
936            r#""artifact""#
937        );
938        assert_eq!(
939            serde_json::to_string(&ProbeCategory::Continuation).unwrap(),
940            r#""continuation""#
941        );
942        assert_eq!(
943            serde_json::to_string(&ProbeCategory::Decision).unwrap(),
944            r#""decision""#
945        );
946        let cat: ProbeCategory = serde_json::from_str(r#""recall""#).unwrap();
947        assert_eq!(cat, ProbeCategory::Recall);
948    }
949
950    // --- category_weights TOML compat ---
951
952    #[test]
953    fn category_weights_toml_round_trip() {
954        let toml_str = r#"
955enabled = true
956probe_provider = "fast"
957threshold = 0.6
958hard_fail_threshold = 0.35
959max_questions = 5
960timeout_secs = 15
961[category_weights]
962recall = 1.5
963artifact = 1.0
964continuation = 1.0
965decision = 0.8
966"#;
967        let c: CompactionProbeConfig = toml::from_str(toml_str).expect("deserialize toml");
968        let weights = c.category_weights.as_ref().unwrap();
969        assert!((weights[&ProbeCategory::Recall] - 1.5).abs() < 0.001);
970        assert!((weights[&ProbeCategory::Decision] - 0.8).abs() < 0.001);
971    }
972
973    // --- compute_category_scores ---
974
975    #[test]
976    fn category_scores_equal_weights() {
977        let questions = vec![
978            ProbeQuestion {
979                question: "Q1".into(),
980                expected_answer: "A1".into(),
981                category: ProbeCategory::Recall,
982            },
983            ProbeQuestion {
984                question: "Q2".into(),
985                expected_answer: "A2".into(),
986                category: ProbeCategory::Artifact,
987            },
988        ];
989        let scores = [1.0_f32, 0.0_f32];
990        let (cats, overall) = compute_category_scores(&questions, &scores, None);
991        assert_eq!(cats.len(), 2);
992        // Equal weight: (1.0 + 0.0) / 2 = 0.5
993        assert!((overall - 0.5).abs() < 0.001, "overall={overall}");
994    }
995
996    #[test]
997    fn category_scores_missing_category_excluded() {
998        // Only Recall and Decision present; Artifact/Continuation absent.
999        let questions = vec![
1000            ProbeQuestion {
1001                question: "Q1".into(),
1002                expected_answer: "A1".into(),
1003                category: ProbeCategory::Recall,
1004            },
1005            ProbeQuestion {
1006                question: "Q2".into(),
1007                expected_answer: "A2".into(),
1008                category: ProbeCategory::Decision,
1009            },
1010        ];
1011        let scores = [1.0_f32, 0.6_f32];
1012        let (cats, _overall) = compute_category_scores(&questions, &scores, None);
1013        assert_eq!(cats.len(), 2, "only categories with questions present");
1014        let categories: Vec<_> = cats.iter().map(|c| c.category).collect();
1015        assert!(!categories.contains(&ProbeCategory::Artifact));
1016        assert!(!categories.contains(&ProbeCategory::Continuation));
1017    }
1018
1019    #[test]
1020    fn category_scores_custom_weights() {
1021        let questions = vec![
1022            ProbeQuestion {
1023                question: "Q1".into(),
1024                expected_answer: "A1".into(),
1025                category: ProbeCategory::Recall,
1026            },
1027            ProbeQuestion {
1028                question: "Q2".into(),
1029                expected_answer: "A2".into(),
1030                category: ProbeCategory::Decision,
1031            },
1032        ];
1033        let scores = [1.0_f32, 0.0_f32];
1034        let mut weights = HashMap::new();
1035        weights.insert(ProbeCategory::Recall, 2.0_f32);
1036        weights.insert(ProbeCategory::Decision, 1.0_f32);
1037        let (_, overall) = compute_category_scores(&questions, &scores, Some(&weights));
1038        // (1.0*2 + 0.0*1) / (2+1) = 0.666..
1039        assert!(
1040            (overall - 2.0 / 3.0).abs() < 0.001,
1041            "expected ~0.667, got {overall}"
1042        );
1043    }
1044
1045    #[test]
1046    fn category_scores_all_zero_weights_fallback() {
1047        let questions = vec![
1048            ProbeQuestion {
1049                question: "Q1".into(),
1050                expected_answer: "A1".into(),
1051                category: ProbeCategory::Recall,
1052            },
1053            ProbeQuestion {
1054                question: "Q2".into(),
1055                expected_answer: "A2".into(),
1056                category: ProbeCategory::Artifact,
1057            },
1058        ];
1059        let scores = [1.0_f32, 0.0_f32];
1060        let mut weights = HashMap::new();
1061        weights.insert(ProbeCategory::Recall, 0.0_f32);
1062        weights.insert(ProbeCategory::Artifact, 0.0_f32);
1063        let (_, overall) = compute_category_scores(&questions, &scores, Some(&weights));
1064        // Fallback to equal weighting: (1.0 + 0.0) / 2 = 0.5
1065        assert!((overall - 0.5).abs() < 0.001, "overall={overall}");
1066    }
1067
1068    #[test]
1069    fn category_scores_empty_questions() {
1070        let (cats, overall) = compute_category_scores(&[], &[], None);
1071        assert!(cats.is_empty());
1072        assert!((overall - 0.0).abs() < 0.001);
1073    }
1074
1075    #[test]
1076    fn category_scores_multi_probe_single_category_averages() {
1077        // Three Recall questions with scores 1.0, 0.0, 0.5 → average 0.5.
1078        let questions = vec![
1079            ProbeQuestion {
1080                question: "Q1".into(),
1081                expected_answer: "A1".into(),
1082                category: ProbeCategory::Recall,
1083            },
1084            ProbeQuestion {
1085                question: "Q2".into(),
1086                expected_answer: "A2".into(),
1087                category: ProbeCategory::Recall,
1088            },
1089            ProbeQuestion {
1090                question: "Q3".into(),
1091                expected_answer: "A3".into(),
1092                category: ProbeCategory::Recall,
1093            },
1094        ];
1095        let scores = [1.0_f32, 0.0_f32, 0.5_f32];
1096        let (cats, overall) = compute_category_scores(&questions, &scores, None);
1097        assert_eq!(cats.len(), 1, "only one category present");
1098        assert_eq!(cats[0].category, ProbeCategory::Recall);
1099        assert_eq!(cats[0].probes_run, 3);
1100        assert!(
1101            (cats[0].score - 0.5).abs() < 0.001,
1102            "cat score={}",
1103            cats[0].score
1104        );
1105        // With one category, overall equals that category's average.
1106        assert!((overall - 0.5).abs() < 0.001, "overall={overall}");
1107    }
1108
1109    #[test]
1110    fn probe_question_serde_default_category() {
1111        // Old JSON without `category` field must deserialize to Recall via #[serde(default)].
1112        let json = r#"{"question":"What file?","expected_answer":"src/lib.rs"}"#;
1113        let q: ProbeQuestion = serde_json::from_str(json).expect("deserialize");
1114        assert_eq!(q.category, ProbeCategory::Recall);
1115        assert_eq!(q.question, "What file?");
1116        assert_eq!(q.expected_answer, "src/lib.rs");
1117    }
1118
1119    #[test]
1120    fn probe_question_serde_all_categories_round_trip() {
1121        for cat in [
1122            ProbeCategory::Recall,
1123            ProbeCategory::Artifact,
1124            ProbeCategory::Continuation,
1125            ProbeCategory::Decision,
1126        ] {
1127            let q = ProbeQuestion {
1128                question: "test?".into(),
1129                expected_answer: "answer".into(),
1130                category: cat,
1131            };
1132            let json = serde_json::to_string(&q).expect("serialize");
1133            let restored: ProbeQuestion = serde_json::from_str(&json).expect("deserialize");
1134            assert_eq!(restored.category, cat);
1135        }
1136    }
1137}