Skip to main content

zeph_memory/
compaction_probe.rs

1// SPDX-FileCopyrightText: 2026 Andrei G <bug-ops>
2// SPDX-License-Identifier: MIT OR Apache-2.0
3
4//! Compaction probe: validates summary quality before committing it to the context.
5//!
6//! Generates factual questions from the messages being compacted, then answers them
7//! using only the summary text, and scores the answers against expected values.
8//! Returns a [`CompactionProbeResult`] that the caller uses to decide whether to
9//! commit or reject the summary.
10
11use std::collections::HashMap;
12use std::time::Instant;
13
14use schemars::JsonSchema;
15use serde::{Deserialize, Serialize};
16use zeph_llm::any::AnyProvider;
17use zeph_llm::provider::{LlmProvider as _, Message, MessageMetadata, MessagePart, Role};
18
19use crate::error::MemoryError;
20
21// --- Data structures ---
22
/// Functional category of a probe question.
// `Hash` + `Eq` let categories key the `HashMap`s used for per-category grouping
// and for `category_weights`. `rename_all = "lowercase"` gives the wire values
// `recall`, `artifact`, `continuation`, `decision` — the same spellings the
// question-generation prompt asks the LLM to emit.
// NOTE(review): the `///` docs on this enum and its variants are likely copied into
// the derived JSON schema as `description`s (schemars behavior), so rewording them
// may change what the LLM sees — treat them as part of the prompt contract.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "lowercase")]
pub enum ProbeCategory {
    /// Did specific facts survive? (file paths, function names, values, decisions)
    Recall,
    /// Does the agent know which files/tools/URLs it used?
    Artifact,
    /// Can it pick up mid-task? (current step, next steps, blockers, open questions)
    Continuation,
    /// Are past reasoning traces intact? (why X over Y, trade-offs, constraints)
    Decision,
}
36
37/// Per-category scoring breakdown from a compaction probe run.
38#[derive(Debug, Clone, Serialize, Deserialize)]
39pub struct CategoryScore {
40    pub category: ProbeCategory,
41    /// Average score of all questions in this category, in [0.0, 1.0].
42    pub score: f32,
43    /// Number of questions generated for this category.
44    pub probes_run: u32,
45}
46
/// A single factual question with the expected answer.
// NOTE(review): `#[derive(JsonSchema)]` likely turns the `///` docs on this struct
// and its fields into schema `description`s seen by the LLM, so they are part of
// the prompt contract — change them deliberately, not cosmetically.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct ProbeQuestion {
    /// Factual question about the compacted messages.
    pub question: String,
    /// Expected correct answer extractable from the original messages.
    pub expected_answer: String,
    /// Functional category of this question.
    // If the LLM omits `category`, deserialization falls back to `Recall`.
    #[serde(default = "default_probe_category")]
    pub category: ProbeCategory,
}

/// Serde fallback for a missing [`ProbeQuestion::category`]: [`ProbeCategory::Recall`].
fn default_probe_category() -> ProbeCategory {
    ProbeCategory::Recall
}

/// An empty question/answer pair categorized as [`ProbeCategory::Recall`].
impl Default for ProbeQuestion {
    fn default() -> Self {
        Self {
            question: String::new(),
            expected_answer: String::new(),
            category: ProbeCategory::Recall,
        }
    }
}
72
73/// Three-tier verdict for compaction probe quality.
74#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
75pub enum ProbeVerdict {
76    /// Score >= `threshold`: summary preserves enough context. Proceed.
77    Pass,
78    /// Score in [`hard_fail_threshold`, `threshold`): summary is borderline.
79    /// Proceed with compaction but log a warning.
80    SoftFail,
81    /// Score < `hard_fail_threshold`: summary lost critical facts. Block compaction.
82    HardFail,
83    /// Transport/timeout failure — no quality score produced.
84    Error,
85}
86
87/// Full result of a compaction probe run.
88#[derive(Debug, Clone, Serialize, Deserialize)]
89pub struct CompactionProbeResult {
90    /// Overall score in [0.0, 1.0].
91    pub score: f32,
92    /// Per-category breakdown. Categories with 0 questions are omitted.
93    #[serde(default)]
94    pub category_scores: Vec<CategoryScore>,
95    /// Per-question breakdown.
96    pub questions: Vec<ProbeQuestion>,
97    /// LLM answers to the questions (positionally aligned with `questions`).
98    pub answers: Vec<String>,
99    /// Per-question similarity scores.
100    pub per_question_scores: Vec<f32>,
101    pub verdict: ProbeVerdict,
102    /// Pass threshold used for this run.
103    pub threshold: f32,
104    /// Hard-fail threshold used for this run.
105    pub hard_fail_threshold: f32,
106    /// Model name used for the probe.
107    pub model: String,
108    /// Wall-clock duration in milliseconds.
109    pub duration_ms: u64,
110}
111
112// --- Structured LLM output types ---
113
// Top-level JSON shape requested from the question-generation call,
// `{"questions": [...]}`, deserialized through `chat_typed_erased`.
// Kept as plain `//` comments: `///` docs on a `JsonSchema` type would likely be
// emitted into the schema sent to the LLM.
#[derive(Debug, Deserialize, JsonSchema)]
struct ProbeQuestionsOutput {
    questions: Vec<ProbeQuestion>,
}
118
119// --- Category scoring ---
120
121/// Group per-question scores by category and compute per-category averages.
122///
123/// Categories with no questions are excluded from the returned list.
124#[must_use]
125fn compute_category_scores(
126    questions: &[ProbeQuestion],
127    per_question_scores: &[f32],
128    category_weights: Option<&HashMap<ProbeCategory, f32>>,
129) -> (Vec<CategoryScore>, f32) {
130    // Group scores by category.
131    let mut by_cat: HashMap<ProbeCategory, Vec<f32>> = HashMap::new();
132    for (q, &s) in questions.iter().zip(per_question_scores.iter()) {
133        by_cat.entry(q.category).or_default().push(s);
134    }
135
136    #[allow(clippy::cast_precision_loss)]
137    let category_scores: Vec<CategoryScore> = by_cat
138        .into_iter()
139        .map(|(category, scores)| {
140            let avg = scores.iter().sum::<f32>() / scores.len() as f32;
141            CategoryScore {
142                category,
143                score: avg,
144                #[allow(clippy::cast_possible_truncation)]
145                probes_run: scores.len() as u32,
146            }
147        })
148        .collect();
149
150    if category_scores.is_empty() {
151        return (category_scores, 0.0);
152    }
153
154    // Compute weighted overall score. Default: equal weights.
155    let mut weighted_sum = 0.0_f32;
156    let mut weight_total = 0.0_f32;
157    for cs in &category_scores {
158        let raw_w = category_weights
159            .and_then(|m| m.get(&cs.category).copied())
160            .unwrap_or(1.0);
161        if raw_w < 0.0 {
162            tracing::warn!(
163                category = ?cs.category,
164                weight = raw_w,
165                "category_weights contains a negative value — treating as 0.0 (category excluded from scoring)"
166            );
167        }
168        let w = raw_w.max(0.0);
169        weighted_sum += cs.score * w;
170        weight_total += w;
171    }
172
173    let overall = if weight_total > 0.0 {
174        weighted_sum / weight_total
175    } else {
176        // All weights are 0 — fall back to equal weighting.
177        #[allow(clippy::cast_precision_loss)]
178        let n = category_scores.len() as f32;
179        category_scores.iter().map(|cs| cs.score).sum::<f32>() / n
180    };
181
182    (category_scores, overall)
183}
184
// Top-level JSON shape requested from the answering call,
// `{"answers": ["...", ...]}`, deserialized through `chat_typed_erased`.
// Kept as plain `//` comments: `///` docs on a `JsonSchema` type would likely be
// emitted into the schema sent to the LLM.
#[derive(Debug, Deserialize, JsonSchema)]
struct ProbeAnswersOutput {
    answers: Vec<String>,
}
189
190// --- Scoring ---
191
/// Phrases signalling that the answering model could not find the requested
/// information in the summary. Matched case-insensitively as substrings by
/// [`is_refusal`].
const REFUSAL_PATTERNS: &[&str] = &[
    "unknown",
    "not mentioned",
    "not found",
    "n/a",
    "cannot determine",
    "no information",
    "not provided",
    "not specified",
    "not stated",
    "not available",
];

/// Returns `true` when `text`, lowercased, contains any entry of [`REFUSAL_PATTERNS`].
fn is_refusal(text: &str) -> bool {
    let haystack = text.to_lowercase();
    REFUSAL_PATTERNS
        .iter()
        .copied()
        .any(|needle| haystack.contains(needle))
}
210
/// Normalize text into comparison tokens: lowercase it, split on every
/// non-alphanumeric character, and keep only tokens of at least 3 characters.
///
/// Length is counted in `char`s (Unicode scalar values), not bytes. Counting bytes
/// made the rule script-dependent: the 2-letter Cyrillic word "на" is 4 bytes and
/// slipped through, while the ASCII "on" was dropped.
fn normalize_tokens(text: &str) -> Vec<String> {
    text.to_lowercase()
        .split(|c: char| !c.is_alphanumeric())
        .filter(|t| t.chars().count() >= 3)
        .map(String::from)
        .collect()
}
219
/// Jaccard similarity `|A ∩ B| / |A ∪ B|` between two token lists, each treated
/// as a set (repeated tokens count once).
///
/// Two empty lists are considered identical (1.0). An empty list against a
/// non-empty one shares nothing (0.0).
#[allow(clippy::cast_precision_loss)]
fn jaccard(a: &[String], b: &[String]) -> f32 {
    use std::collections::HashSet;

    if a.is_empty() && b.is_empty() {
        return 1.0;
    }

    let left: HashSet<&str> = a.iter().map(String::as_str).collect();
    let right: HashSet<&str> = b.iter().map(String::as_str).collect();
    let common = left.intersection(&right).count();
    // Inclusion–exclusion: |A ∪ B| = |A| + |B| - |A ∩ B|.
    let combined = left.len() + right.len() - common;

    if combined == 0 {
        return 0.0;
    }
    common as f32 / combined as f32
}
236
237/// Score a single (expected, actual) answer pair using token-set-ratio.
238fn score_pair(expected: &str, actual: &str) -> f32 {
239    if is_refusal(actual) {
240        return 0.0;
241    }
242
243    let tokens_e = normalize_tokens(expected);
244    let tokens_a = normalize_tokens(actual);
245
246    // Substring boost: if all expected tokens appear in actual, it's an exact match.
247    if !tokens_e.is_empty() {
248        let set_e: std::collections::HashSet<&str> = tokens_e.iter().map(String::as_str).collect();
249        let set_a: std::collections::HashSet<&str> = tokens_a.iter().map(String::as_str).collect();
250        if set_e.is_subset(&set_a) {
251            return 1.0;
252        }
253    }
254
255    // Token-set-ratio: max of three Jaccard variants.
256    let j_full = jaccard(&tokens_e, &tokens_a);
257
258    // Intersection with each set individually (handles subset relationships).
259    let set_e: std::collections::HashSet<&str> = tokens_e.iter().map(String::as_str).collect();
260    let set_a: std::collections::HashSet<&str> = tokens_a.iter().map(String::as_str).collect();
261    let intersection: Vec<String> = set_e
262        .intersection(&set_a)
263        .map(|s| (*s).to_owned())
264        .collect();
265
266    #[allow(clippy::cast_precision_loss)]
267    let j_e = if tokens_e.is_empty() {
268        0.0_f32
269    } else {
270        intersection.len() as f32 / tokens_e.len() as f32
271    };
272    #[allow(clippy::cast_precision_loss)]
273    let j_a = if tokens_a.is_empty() {
274        0.0_f32
275    } else {
276        intersection.len() as f32 / tokens_a.len() as f32
277    };
278
279    j_full.max(j_e).max(j_a)
280}
281
282/// Score answers against expected values using token-set-ratio similarity.
283///
284/// Returns `(per_question_scores, overall_average)`.
285#[must_use]
286pub fn score_answers(questions: &[ProbeQuestion], answers: &[String]) -> (Vec<f32>, f32) {
287    if questions.is_empty() {
288        return (vec![], 0.0);
289    }
290    let scores: Vec<f32> = questions
291        .iter()
292        .zip(answers.iter().chain(std::iter::repeat(&String::new())))
293        .map(|(q, a)| score_pair(&q.expected_answer, a))
294        .collect();
295    #[allow(clippy::cast_precision_loss)]
296    let avg = if scores.is_empty() {
297        0.0
298    } else {
299        scores.iter().sum::<f32>() / scores.len() as f32
300    };
301    (scores, avg)
302}
303
304// --- LLM calls ---
305
306/// Truncate tool-result bodies to 500 chars to avoid flooding the probe with raw output.
307fn truncate_tool_bodies(messages: &[Message]) -> Vec<Message> {
308    messages
309        .iter()
310        .map(|m| {
311            let mut msg = m.clone();
312            for part in &mut msg.parts {
313                if let MessagePart::ToolOutput { body, .. } = part {
314                    if body.len() <= 500 {
315                        continue;
316                    }
317                    body.truncate(500);
318                    body.push('\u{2026}');
319                }
320            }
321            msg.rebuild_content();
322            msg
323        })
324        .collect()
325}
326
327/// Generate factual probe questions from the messages being compacted.
328///
329/// Uses a single LLM call with structured output. Tool-result bodies are
330/// truncated to 500 chars to focus on decisions and outcomes rather than raw tool output.
331///
332/// # Errors
333///
334/// Returns `MemoryError::Llm` if the LLM call fails.
335#[cfg_attr(
336    feature = "profiling",
337    tracing::instrument(name = "memory.compaction_probe", skip_all)
338)]
339pub async fn generate_probe_questions(
340    provider: &AnyProvider,
341    messages: &[Message],
342    max_questions: usize,
343) -> Result<Vec<ProbeQuestion>, MemoryError> {
344    let truncated = truncate_tool_bodies(messages);
345
346    let mut history = String::new();
347    for msg in &truncated {
348        let role = match msg.role {
349            Role::User => "user",
350            Role::Assistant => "assistant",
351            Role::System => "system",
352        };
353        history.push_str(role);
354        history.push_str(": ");
355        history.push_str(&msg.content);
356        history.push('\n');
357    }
358
359    let prompt = format!(
360        "Given the following conversation excerpt, generate {max_questions} factual questions \
361         that test whether a summary preserves the most important concrete details.\n\
362         \n\
363         You MUST generate at least one question per category when max_questions >= 4. \
364         If the conversation lacks information for a category, generate a question noting that absence.\n\
365         \n\
366         Categories:\n\
367         - recall: Specific facts that survived (file paths, function names, values). \
368           Example: \"What file was modified?\"\n\
369         - artifact: Which files/tools/URLs the agent used. \
370           Example: \"Which tool was executed?\"\n\
371         - continuation: Next steps, blockers, open questions. \
372           Example: \"What is the next step?\"\n\
373         - decision: Past reasoning traces (why X over Y, trade-offs). \
374           Example: \"Why was X chosen over Y?\"\n\
375         \n\
376         Do NOT generate questions about:\n\
377         - Raw tool output content (compiler warnings, test output line numbers)\n\
378         - Intermediate debugging steps that were superseded\n\
379         - Opinions or reasoning that cannot be verified\n\
380         \n\
381         Each question must have a single unambiguous expected answer extractable from the text.\n\
382         \n\
383         Conversation:\n{history}\n\
384         \n\
385         Respond in JSON with schema: {{\"questions\": [{{\"question\": \"...\", \
386         \"expected_answer\": \"...\", \"category\": \"recall|artifact|continuation|decision\"}}]}}"
387    );
388
389    let msgs = [Message {
390        role: Role::User,
391        content: prompt,
392        parts: vec![],
393        metadata: MessageMetadata::default(),
394    }];
395
396    let mut output: ProbeQuestionsOutput = provider
397        .chat_typed_erased::<ProbeQuestionsOutput>(&msgs)
398        .await
399        .map_err(MemoryError::Llm)?;
400
401    // Cap the list to max_questions: a misbehaving LLM could return more.
402    output.questions.truncate(max_questions);
403
404    Ok(output.questions)
405}
406
407/// Answer probe questions using only the compaction summary as context.
408///
409/// # Errors
410///
411/// Returns `MemoryError::Llm` if the LLM call fails.
412pub async fn answer_probe_questions(
413    provider: &AnyProvider,
414    summary: &str,
415    questions: &[ProbeQuestion],
416) -> Result<Vec<String>, MemoryError> {
417    let mut numbered = String::new();
418    for (i, q) in questions.iter().enumerate() {
419        use std::fmt::Write as _;
420        let _ = writeln!(numbered, "{}. {}", i + 1, q.question);
421    }
422
423    let prompt = format!(
424        "Given the following summary of a conversation, answer each question using ONLY \
425         information present in the summary. If the answer is not in the summary, respond \
426         with \"UNKNOWN\".\n\
427         \n\
428         Summary:\n{summary}\n\
429         \n\
430         Questions:\n{numbered}\n\
431         \n\
432         Respond in JSON with schema: {{\"answers\": [\"answer1\", \"answer2\", ...]}}"
433    );
434
435    let msgs = [Message {
436        role: Role::User,
437        content: prompt,
438        parts: vec![],
439        metadata: MessageMetadata::default(),
440    }];
441
442    let output: ProbeAnswersOutput = provider
443        .chat_typed_erased::<ProbeAnswersOutput>(&msgs)
444        .await
445        .map_err(MemoryError::Llm)?;
446
447    Ok(output.answers)
448}
449
450/// Configuration for the compaction probe.
451#[derive(Debug, Clone, Serialize, Deserialize)]
452#[serde(default)]
453pub struct CompactionProbeConfig {
454    /// Enable compaction probe validation. Default: `false`.
455    pub enabled: bool,
456    /// Provider name from `[[llm.providers]]` for probe LLM calls.
457    /// Empty string = use the summary provider.
458    pub probe_provider: String,
459    /// Minimum score to pass without warnings. Default: `0.6`.
460    /// Scores in [`hard_fail_threshold`, `threshold`) trigger `SoftFail` (warn + proceed).
461    pub threshold: f32,
462    /// Score below this triggers `HardFail` (block compaction). Default: `0.35`.
463    pub hard_fail_threshold: f32,
464    /// Maximum number of probe questions to generate. Default: `5`.
465    pub max_questions: usize,
466    /// Timeout for the entire probe (both LLM calls) in seconds. Default: `15`.
467    pub timeout_secs: u64,
468    /// Optional per-category weight multipliers for the overall score.
469    /// When `None` or empty, all categories are weighted equally.
470    /// Example: `{ recall = 1.5, artifact = 1.0, continuation = 1.0, decision = 0.8 }`
471    #[serde(default)]
472    pub category_weights: Option<HashMap<ProbeCategory, f32>>,
473}
474
475impl Default for CompactionProbeConfig {
476    fn default() -> Self {
477        Self {
478            enabled: false,
479            probe_provider: String::new(),
480            threshold: 0.6,
481            hard_fail_threshold: 0.35,
482            max_questions: 5,
483            timeout_secs: 15,
484            category_weights: None,
485        }
486    }
487}
488
489/// Run the compaction probe: generate questions, answer them from the summary, score results.
490///
491/// Returns `Ok(None)` when:
492/// - Probe is disabled (`config.enabled = false`)
493/// - The probe times out
494/// - Fewer than 2 questions are generated (insufficient statistical power)
495///
496/// The caller treats `None` as "no opinion" and proceeds with compaction.
497///
498/// # Errors
499///
500/// Returns `MemoryError` if an LLM call fails. Callers should treat this as non-fatal
501/// and proceed with compaction.
502pub async fn validate_compaction(
503    provider: AnyProvider,
504    messages: Vec<Message>,
505    summary: String,
506    config: &CompactionProbeConfig,
507) -> Result<Option<CompactionProbeResult>, MemoryError> {
508    if !config.enabled {
509        return Ok(None);
510    }
511
512    let timeout = std::time::Duration::from_secs(config.timeout_secs);
513    let start = Instant::now();
514
515    let result = tokio::time::timeout(timeout, async {
516        run_probe(provider, messages, summary, config).await
517    })
518    .await;
519
520    match result {
521        Ok(inner) => inner,
522        Err(_elapsed) => {
523            tracing::warn!(
524                timeout_secs = config.timeout_secs,
525                "compaction probe timed out — proceeding with compaction"
526            );
527            Ok(None)
528        }
529    }
530    .map(|opt| {
531        opt.map(|mut r| {
532            r.duration_ms = u64::try_from(start.elapsed().as_millis()).unwrap_or(u64::MAX);
533            r
534        })
535    })
536}
537
538async fn run_probe(
539    provider: AnyProvider,
540    messages: Vec<Message>,
541    summary: String,
542    config: &CompactionProbeConfig,
543) -> Result<Option<CompactionProbeResult>, MemoryError> {
544    if summary.len() < 10 {
545        tracing::warn!(
546            len = summary.len(),
547            "compaction probe: summary too short — skipping probe"
548        );
549        return Ok(None);
550    }
551
552    let questions = generate_probe_questions(&provider, &messages, config.max_questions).await?;
553
554    if questions.len() < 2 {
555        tracing::debug!(
556            count = questions.len(),
557            "compaction probe: fewer than 2 questions generated — skipping probe"
558        );
559        return Ok(None);
560    }
561
562    // Warn if any category is missing when we expected full coverage.
563    if config.max_questions >= 4 {
564        use std::collections::HashSet;
565        let covered: HashSet<_> = questions.iter().map(|q| q.category).collect();
566        for cat in [
567            ProbeCategory::Recall,
568            ProbeCategory::Artifact,
569            ProbeCategory::Continuation,
570            ProbeCategory::Decision,
571        ] {
572            if !covered.contains(&cat) {
573                tracing::warn!(
574                    category = ?cat,
575                    "compaction probe: LLM did not generate questions for category"
576                );
577            }
578        }
579    }
580
581    let answers = answer_probe_questions(&provider, &summary, &questions).await?;
582
583    let (per_question_scores, _simple_avg) = score_answers(&questions, &answers);
584
585    let (category_scores, score) = compute_category_scores(
586        &questions,
587        &per_question_scores,
588        config.category_weights.as_ref(),
589    );
590
591    let verdict = if score >= config.threshold {
592        ProbeVerdict::Pass
593    } else if score >= config.hard_fail_threshold {
594        ProbeVerdict::SoftFail
595    } else {
596        ProbeVerdict::HardFail
597    };
598
599    let model = provider.name().to_owned();
600
601    Ok(Some(CompactionProbeResult {
602        score,
603        category_scores,
604        questions,
605        answers,
606        per_question_scores,
607        verdict,
608        threshold: config.threshold,
609        hard_fail_threshold: config.hard_fail_threshold,
610        model,
611        duration_ms: 0, // filled in by validate_compaction
612    }))
613}
614
615#[cfg(test)]
616mod tests {
617    use super::*;
618
619    // --- score_answers tests ---
620
621    #[test]
622    fn score_perfect_match() {
623        let q = vec![ProbeQuestion {
624            question: "What crate is used?".into(),
625            expected_answer: "thiserror".into(),
626            category: ProbeCategory::Recall,
627        }];
628        let a = vec!["thiserror".into()];
629        let (scores, avg) = score_answers(&q, &a);
630        assert_eq!(scores.len(), 1);
631        assert!((avg - 1.0).abs() < 0.01, "expected ~1.0, got {avg}");
632    }
633
634    #[test]
635    fn score_complete_mismatch() {
636        let q = vec![ProbeQuestion {
637            question: "What file was modified?".into(),
638            expected_answer: "src/auth.rs".into(),
639            ..Default::default()
640        }];
641        let a = vec!["definitely not in the summary".into()];
642        let (scores, avg) = score_answers(&q, &a);
643        assert_eq!(scores.len(), 1);
644        // Very low overlap expected.
645        assert!(avg < 0.5, "expected low score, got {avg}");
646    }
647
648    #[test]
649    fn score_refusal_is_zero() {
650        let q = vec![ProbeQuestion {
651            question: "What was the decision?".into(),
652            expected_answer: "Use thiserror for typed errors".into(),
653            ..Default::default()
654        }];
655        for refusal in &[
656            "UNKNOWN",
657            "not mentioned",
658            "N/A",
659            "cannot determine",
660            "No information",
661        ] {
662            let a = vec![(*refusal).to_owned()];
663            let (_, avg) = score_answers(&q, &a);
664            assert!(avg < 0.01, "expected 0 for refusal '{refusal}', got {avg}");
665        }
666    }
667
668    #[test]
669    fn score_paraphrased_answer_above_half() {
670        // "thiserror was chosen for error types" vs "Use thiserror for typed errors"
671        // Shared tokens: "thiserror", "error" (and maybe "for"/"types"/"typed" with >=3 chars)
672        let q = vec![ProbeQuestion {
673            question: "What error handling crate was chosen?".into(),
674            expected_answer: "Use thiserror for typed errors in library crates".into(),
675            ..Default::default()
676        }];
677        let a = vec!["thiserror was chosen for error types in library crates".into()];
678        let (_, avg) = score_answers(&q, &a);
679        assert!(avg > 0.5, "expected >0.5 for paraphrase, got {avg}");
680    }
681
682    #[test]
683    fn score_empty_strings() {
684        let q = vec![ProbeQuestion {
685            question: "What?".into(),
686            expected_answer: String::new(),
687            ..Default::default()
688        }];
689        let a = vec![String::new()];
690        let (scores, avg) = score_answers(&q, &a);
691        assert_eq!(scores.len(), 1);
692        // Both empty — jaccard of two empty sets returns 1.0 (exact match).
693        assert!(
694            (avg - 1.0).abs() < 0.01,
695            "expected 1.0 for empty vs empty, got {avg}"
696        );
697    }
698
699    #[test]
700    fn score_empty_questions_list() {
701        let (scores, avg) = score_answers(&[], &[]);
702        assert!(scores.is_empty());
703        assert!((avg - 0.0).abs() < 0.01);
704    }
705
706    #[test]
707    fn score_file_path_exact() {
708        let q = vec![ProbeQuestion {
709            question: "Which file was modified?".into(),
710            expected_answer: "crates/zeph-memory/src/compaction_probe.rs".into(),
711            ..Default::default()
712        }];
713        let a = vec!["The file crates/zeph-memory/src/compaction_probe.rs was modified.".into()];
714        let (_, avg) = score_answers(&q, &a);
715        // Substring boost should fire: all expected tokens present in actual.
716        assert!(
717            avg > 0.8,
718            "expected high score for file path match, got {avg}"
719        );
720    }
721
722    #[test]
723    fn score_unicode_input() {
724        let q = vec![ProbeQuestion {
725            question: "Что было изменено?".into(),
726            expected_answer: "файл config.toml".into(),
727            ..Default::default()
728        }];
729        let a = vec!["config.toml был изменён".into()];
730        // Just verify no panic; score may vary.
731        let (scores, _) = score_answers(&q, &a);
732        assert_eq!(scores.len(), 1);
733    }
734
735    // --- verdict threshold tests ---
736
737    #[test]
738    fn verdict_thresholds() {
739        let config = CompactionProbeConfig::default();
740
741        // Pass >= 0.6
742        let score = 0.7_f32;
743        let verdict = if score >= config.threshold {
744            ProbeVerdict::Pass
745        } else if score >= config.hard_fail_threshold {
746            ProbeVerdict::SoftFail
747        } else {
748            ProbeVerdict::HardFail
749        };
750        assert_eq!(verdict, ProbeVerdict::Pass);
751
752        // SoftFail [0.35, 0.6)
753        let score = 0.5_f32;
754        let verdict = if score >= config.threshold {
755            ProbeVerdict::Pass
756        } else if score >= config.hard_fail_threshold {
757            ProbeVerdict::SoftFail
758        } else {
759            ProbeVerdict::HardFail
760        };
761        assert_eq!(verdict, ProbeVerdict::SoftFail);
762
763        // HardFail < 0.35
764        let score = 0.2_f32;
765        let verdict = if score >= config.threshold {
766            ProbeVerdict::Pass
767        } else if score >= config.hard_fail_threshold {
768            ProbeVerdict::SoftFail
769        } else {
770            ProbeVerdict::HardFail
771        };
772        assert_eq!(verdict, ProbeVerdict::HardFail);
773    }
774
775    // --- config defaults ---
776
777    #[test]
778    fn config_defaults() {
779        let c = CompactionProbeConfig::default();
780        assert!(!c.enabled);
781        assert!(c.probe_provider.is_empty());
782        assert!((c.threshold - 0.6).abs() < 0.001);
783        assert!((c.hard_fail_threshold - 0.35).abs() < 0.001);
784        assert_eq!(c.max_questions, 5);
785        assert_eq!(c.timeout_secs, 15);
786        assert!(c.category_weights.is_none());
787    }
788
789    // --- serde round-trip ---
790
791    #[test]
792    fn config_serde_round_trip() {
793        let original = CompactionProbeConfig {
794            enabled: true,
795            probe_provider: "fast".into(),
796            threshold: 0.65,
797            hard_fail_threshold: 0.4,
798            max_questions: 5,
799            timeout_secs: 20,
800            category_weights: None,
801        };
802        let json = serde_json::to_string(&original).expect("serialize");
803        let restored: CompactionProbeConfig = serde_json::from_str(&json).expect("deserialize");
804        assert!(restored.enabled);
805        assert_eq!(restored.probe_provider, "fast");
806        assert!((restored.threshold - 0.65).abs() < 0.001);
807    }
808
809    #[test]
810    fn probe_result_serde_round_trip() {
811        let result = CompactionProbeResult {
812            score: 0.75,
813            category_scores: vec![CategoryScore {
814                category: ProbeCategory::Recall,
815                score: 0.75,
816                probes_run: 1,
817            }],
818            questions: vec![ProbeQuestion {
819                question: "What?".into(),
820                expected_answer: "thiserror".into(),
821                category: ProbeCategory::Recall,
822            }],
823            answers: vec!["thiserror".into()],
824            per_question_scores: vec![1.0],
825            verdict: ProbeVerdict::Pass,
826            threshold: 0.6,
827            hard_fail_threshold: 0.35,
828            model: "haiku".into(),
829            duration_ms: 1234,
830        };
831        let json = serde_json::to_string(&result).expect("serialize");
832        let restored: CompactionProbeResult = serde_json::from_str(&json).expect("deserialize");
833        assert!((restored.score - 0.75).abs() < 0.001);
834        assert_eq!(restored.verdict, ProbeVerdict::Pass);
835        assert_eq!(restored.category_scores.len(), 1);
836    }
837
838    #[test]
839    fn probe_result_backward_compat_no_category_scores() {
840        // Old JSON without category_scores field must deserialize with empty vec.
841        let json = r#"{"score":0.75,"questions":[],"answers":[],"per_question_scores":[],"verdict":"Pass","threshold":0.6,"hard_fail_threshold":0.35,"model":"haiku","duration_ms":0}"#;
842        let restored: CompactionProbeResult = serde_json::from_str(json).expect("deserialize");
843        assert!(restored.category_scores.is_empty());
844    }
845
846    // --- fewer answers than questions (LLM returned truncated list) ---
847
848    #[test]
849    fn score_fewer_answers_than_questions() {
850        let questions = vec![
851            ProbeQuestion {
852                question: "What crate?".into(),
853                expected_answer: "thiserror".into(),
854                ..Default::default()
855            },
856            ProbeQuestion {
857                question: "What file?".into(),
858                expected_answer: "src/lib.rs".into(),
859                ..Default::default()
860            },
861            ProbeQuestion {
862                question: "What decision?".into(),
863                expected_answer: "use async traits".into(),
864                ..Default::default()
865            },
866        ];
867        // LLM only returned 1 answer for 3 questions.
868        let answers = vec!["thiserror".into()];
869        let (scores, avg) = score_answers(&questions, &answers);
870        // scores must have the same length as questions (missing answers → empty string → ~0).
871        assert_eq!(scores.len(), 3);
872        // First answer is a perfect match.
873        assert!(
874            (scores[0] - 1.0).abs() < 0.01,
875            "first score should be ~1.0, got {}",
876            scores[0]
877        );
878        // Missing answers score 0 (empty string vs non-empty expected).
879        assert!(
880            scores[1] < 0.5,
881            "second score should be low for missing answer, got {}",
882            scores[1]
883        );
884        assert!(
885            scores[2] < 0.5,
886            "third score should be low for missing answer, got {}",
887            scores[2]
888        );
889        // Average is dragged down by the two missing answers.
890        assert!(
891            avg < 0.5,
892            "average should be below 0.5 with 2 missing answers, got {avg}"
893        );
894    }
895
896    // --- exact boundary values for threshold ---
897
898    #[test]
899    fn verdict_boundary_at_threshold() {
900        let config = CompactionProbeConfig::default();
901
902        // Exactly at pass threshold → Pass.
903        let score = config.threshold;
904        let verdict = if score >= config.threshold {
905            ProbeVerdict::Pass
906        } else if score >= config.hard_fail_threshold {
907            ProbeVerdict::SoftFail
908        } else {
909            ProbeVerdict::HardFail
910        };
911        assert_eq!(verdict, ProbeVerdict::Pass);
912
913        // One ULP below pass threshold, above hard-fail → SoftFail.
914        let score = config.threshold - f32::EPSILON;
915        let verdict = if score >= config.threshold {
916            ProbeVerdict::Pass
917        } else if score >= config.hard_fail_threshold {
918            ProbeVerdict::SoftFail
919        } else {
920            ProbeVerdict::HardFail
921        };
922        assert_eq!(verdict, ProbeVerdict::SoftFail);
923
924        // Exactly at hard-fail threshold → SoftFail (boundary is inclusive).
925        let score = config.hard_fail_threshold;
926        let verdict = if score >= config.threshold {
927            ProbeVerdict::Pass
928        } else if score >= config.hard_fail_threshold {
929            ProbeVerdict::SoftFail
930        } else {
931            ProbeVerdict::HardFail
932        };
933        assert_eq!(verdict, ProbeVerdict::SoftFail);
934
935        // One ULP below hard-fail threshold → HardFail.
936        let score = config.hard_fail_threshold - f32::EPSILON;
937        let verdict = if score >= config.threshold {
938            ProbeVerdict::Pass
939        } else if score >= config.hard_fail_threshold {
940            ProbeVerdict::SoftFail
941        } else {
942            ProbeVerdict::HardFail
943        };
944        assert_eq!(verdict, ProbeVerdict::HardFail);
945    }
946
947    // --- config partial deserialization (serde default fields) ---
948
949    #[test]
950    fn config_partial_json_uses_defaults() {
951        // Only `enabled` is specified; all other fields must fall back to defaults via #[serde(default)].
952        let json = r#"{"enabled": true}"#;
953        let c: CompactionProbeConfig =
954            serde_json::from_str(json).expect("deserialize partial json");
955        assert!(c.enabled);
956        assert!(c.probe_provider.is_empty());
957        assert!((c.threshold - 0.6).abs() < 0.001);
958        assert!((c.hard_fail_threshold - 0.35).abs() < 0.001);
959        assert_eq!(c.max_questions, 5);
960        assert_eq!(c.timeout_secs, 15);
961    }
962
963    #[test]
964    fn config_empty_json_uses_all_defaults() {
965        let c: CompactionProbeConfig = serde_json::from_str("{}").expect("deserialize empty json");
966        assert!(!c.enabled);
967        assert!(c.probe_provider.is_empty());
968    }
969
970    // --- ProbeCategory serde ---
971
972    #[test]
973    fn probe_category_serde_lowercase() {
974        assert_eq!(
975            serde_json::to_string(&ProbeCategory::Recall).unwrap(),
976            r#""recall""#
977        );
978        assert_eq!(
979            serde_json::to_string(&ProbeCategory::Artifact).unwrap(),
980            r#""artifact""#
981        );
982        assert_eq!(
983            serde_json::to_string(&ProbeCategory::Continuation).unwrap(),
984            r#""continuation""#
985        );
986        assert_eq!(
987            serde_json::to_string(&ProbeCategory::Decision).unwrap(),
988            r#""decision""#
989        );
990        let cat: ProbeCategory = serde_json::from_str(r#""recall""#).unwrap();
991        assert_eq!(cat, ProbeCategory::Recall);
992    }
993
994    // --- category_weights TOML compat ---
995
996    #[test]
997    fn category_weights_toml_round_trip() {
998        let toml_str = r#"
999enabled = true
1000probe_provider = "fast"
1001threshold = 0.6
1002hard_fail_threshold = 0.35
1003max_questions = 5
1004timeout_secs = 15
1005[category_weights]
1006recall = 1.5
1007artifact = 1.0
1008continuation = 1.0
1009decision = 0.8
1010"#;
1011        let c: CompactionProbeConfig = toml::from_str(toml_str).expect("deserialize toml");
1012        let weights = c.category_weights.as_ref().unwrap();
1013        assert!((weights[&ProbeCategory::Recall] - 1.5).abs() < 0.001);
1014        assert!((weights[&ProbeCategory::Decision] - 0.8).abs() < 0.001);
1015    }
1016
1017    // --- compute_category_scores ---
1018
1019    #[test]
1020    fn category_scores_equal_weights() {
1021        let questions = vec![
1022            ProbeQuestion {
1023                question: "Q1".into(),
1024                expected_answer: "A1".into(),
1025                category: ProbeCategory::Recall,
1026            },
1027            ProbeQuestion {
1028                question: "Q2".into(),
1029                expected_answer: "A2".into(),
1030                category: ProbeCategory::Artifact,
1031            },
1032        ];
1033        let scores = [1.0_f32, 0.0_f32];
1034        let (cats, overall) = compute_category_scores(&questions, &scores, None);
1035        assert_eq!(cats.len(), 2);
1036        // Equal weight: (1.0 + 0.0) / 2 = 0.5
1037        assert!((overall - 0.5).abs() < 0.001, "overall={overall}");
1038    }
1039
1040    #[test]
1041    fn category_scores_missing_category_excluded() {
1042        // Only Recall and Decision present; Artifact/Continuation absent.
1043        let questions = vec![
1044            ProbeQuestion {
1045                question: "Q1".into(),
1046                expected_answer: "A1".into(),
1047                category: ProbeCategory::Recall,
1048            },
1049            ProbeQuestion {
1050                question: "Q2".into(),
1051                expected_answer: "A2".into(),
1052                category: ProbeCategory::Decision,
1053            },
1054        ];
1055        let scores = [1.0_f32, 0.6_f32];
1056        let (cats, _overall) = compute_category_scores(&questions, &scores, None);
1057        assert_eq!(cats.len(), 2, "only categories with questions present");
1058        let categories: Vec<_> = cats.iter().map(|c| c.category).collect();
1059        assert!(!categories.contains(&ProbeCategory::Artifact));
1060        assert!(!categories.contains(&ProbeCategory::Continuation));
1061    }
1062
1063    #[test]
1064    fn category_scores_custom_weights() {
1065        let questions = vec![
1066            ProbeQuestion {
1067                question: "Q1".into(),
1068                expected_answer: "A1".into(),
1069                category: ProbeCategory::Recall,
1070            },
1071            ProbeQuestion {
1072                question: "Q2".into(),
1073                expected_answer: "A2".into(),
1074                category: ProbeCategory::Decision,
1075            },
1076        ];
1077        let scores = [1.0_f32, 0.0_f32];
1078        let mut weights = HashMap::new();
1079        weights.insert(ProbeCategory::Recall, 2.0_f32);
1080        weights.insert(ProbeCategory::Decision, 1.0_f32);
1081        let (_, overall) = compute_category_scores(&questions, &scores, Some(&weights));
1082        // (1.0*2 + 0.0*1) / (2+1) = 0.666..
1083        assert!(
1084            (overall - 2.0 / 3.0).abs() < 0.001,
1085            "expected ~0.667, got {overall}"
1086        );
1087    }
1088
1089    #[test]
1090    fn category_scores_all_zero_weights_fallback() {
1091        let questions = vec![
1092            ProbeQuestion {
1093                question: "Q1".into(),
1094                expected_answer: "A1".into(),
1095                category: ProbeCategory::Recall,
1096            },
1097            ProbeQuestion {
1098                question: "Q2".into(),
1099                expected_answer: "A2".into(),
1100                category: ProbeCategory::Artifact,
1101            },
1102        ];
1103        let scores = [1.0_f32, 0.0_f32];
1104        let mut weights = HashMap::new();
1105        weights.insert(ProbeCategory::Recall, 0.0_f32);
1106        weights.insert(ProbeCategory::Artifact, 0.0_f32);
1107        let (_, overall) = compute_category_scores(&questions, &scores, Some(&weights));
1108        // Fallback to equal weighting: (1.0 + 0.0) / 2 = 0.5
1109        assert!((overall - 0.5).abs() < 0.001, "overall={overall}");
1110    }
1111
1112    #[test]
1113    fn category_scores_empty_questions() {
1114        let (cats, overall) = compute_category_scores(&[], &[], None);
1115        assert!(cats.is_empty());
1116        assert!((overall - 0.0).abs() < 0.001);
1117    }
1118
1119    #[test]
1120    fn category_scores_multi_probe_single_category_averages() {
1121        // Three Recall questions with scores 1.0, 0.0, 0.5 → average 0.5.
1122        let questions = vec![
1123            ProbeQuestion {
1124                question: "Q1".into(),
1125                expected_answer: "A1".into(),
1126                category: ProbeCategory::Recall,
1127            },
1128            ProbeQuestion {
1129                question: "Q2".into(),
1130                expected_answer: "A2".into(),
1131                category: ProbeCategory::Recall,
1132            },
1133            ProbeQuestion {
1134                question: "Q3".into(),
1135                expected_answer: "A3".into(),
1136                category: ProbeCategory::Recall,
1137            },
1138        ];
1139        let scores = [1.0_f32, 0.0_f32, 0.5_f32];
1140        let (cats, overall) = compute_category_scores(&questions, &scores, None);
1141        assert_eq!(cats.len(), 1, "only one category present");
1142        assert_eq!(cats[0].category, ProbeCategory::Recall);
1143        assert_eq!(cats[0].probes_run, 3);
1144        assert!(
1145            (cats[0].score - 0.5).abs() < 0.001,
1146            "cat score={}",
1147            cats[0].score
1148        );
1149        // With one category, overall equals that category's average.
1150        assert!((overall - 0.5).abs() < 0.001, "overall={overall}");
1151    }
1152
1153    #[test]
1154    fn probe_question_serde_default_category() {
1155        // Old JSON without `category` field must deserialize to Recall via #[serde(default)].
1156        let json = r#"{"question":"What file?","expected_answer":"src/lib.rs"}"#;
1157        let q: ProbeQuestion = serde_json::from_str(json).expect("deserialize");
1158        assert_eq!(q.category, ProbeCategory::Recall);
1159        assert_eq!(q.question, "What file?");
1160        assert_eq!(q.expected_answer, "src/lib.rs");
1161    }
1162
1163    #[test]
1164    fn probe_question_serde_all_categories_round_trip() {
1165        for cat in [
1166            ProbeCategory::Recall,
1167            ProbeCategory::Artifact,
1168            ProbeCategory::Continuation,
1169            ProbeCategory::Decision,
1170        ] {
1171            let q = ProbeQuestion {
1172                question: "test?".into(),
1173                expected_answer: "answer".into(),
1174                category: cat,
1175            };
1176            let json = serde_json::to_string(&q).expect("serialize");
1177            let restored: ProbeQuestion = serde_json::from_str(&json).expect("deserialize");
1178            assert_eq!(restored.category, cat);
1179        }
1180    }
1181}