Skip to main content

zeph_memory/
compaction_probe.rs

1// SPDX-FileCopyrightText: 2026 Andrei G <bug-ops>
2// SPDX-License-Identifier: MIT OR Apache-2.0
3
4//! Compaction probe: validates summary quality before committing it to the context.
5//!
6//! Generates factual questions from the messages being compacted, then answers them
7//! using only the summary text, and scores the answers against expected values.
8//! Returns a [`CompactionProbeResult`] that the caller uses to decide whether to
9//! commit or reject the summary.
10
11use std::collections::HashMap;
12use std::time::Instant;
13
14use schemars::JsonSchema;
15use serde::{Deserialize, Serialize};
16use zeph_llm::any::AnyProvider;
17use zeph_llm::provider::{LlmProvider as _, Message, MessageMetadata, MessagePart, Role};
18
19use crate::error::MemoryError;
20
21// --- Data structures ---
22
/// Functional category of a probe question.
// NOTE(review): `JsonSchema` is derived here, and this type is reachable from
// `ProbeQuestionsOutput`, which `generate_probe_questions` requests from the LLM via
// `chat_typed_erased`. The `///` doc comments on this enum and its variants likely
// end up as JSON-schema descriptions; if that call forwards the schema to the model,
// editing them changes the prompt — confirm before rewording.
// `Hash`/`Eq` are needed because categories key `HashMap`s (per-category grouping
// and `CompactionProbeConfig::category_weights`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, JsonSchema)]
// Wire names are "recall", "artifact", "continuation", "decision" — the same strings
// the question-generation prompt asks the LLM to use.
#[serde(rename_all = "lowercase")]
pub enum ProbeCategory {
    /// Did specific facts survive? (file paths, function names, values, decisions)
    Recall,
    /// Does the agent know which files/tools/URLs it used?
    Artifact,
    /// Can it pick up mid-task? (current step, next steps, blockers, open questions)
    Continuation,
    /// Are past reasoning traces intact? (why X over Y, trade-offs, constraints)
    Decision,
}
36
37/// Per-category scoring breakdown from a compaction probe run.
38#[derive(Debug, Clone, Serialize, Deserialize)]
39pub struct CategoryScore {
40    pub category: ProbeCategory,
41    /// Average score of all questions in this category, in [0.0, 1.0].
42    pub score: f32,
43    /// Number of questions generated for this category.
44    pub probes_run: u32,
45}
46
/// A single factual question with the expected answer.
// Produced by the question-generation LLM call (inside `ProbeQuestionsOutput`) and
// later scored by comparing `expected_answer` against an answer derived from the
// summary alone.
// NOTE(review): `JsonSchema` is derived, so the `///` field docs below may be sent to
// the model as schema descriptions — keep edits to them deliberate.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct ProbeQuestion {
    /// Factual question about the compacted messages.
    pub question: String,
    /// Expected correct answer extractable from the original messages.
    pub expected_answer: String,
    /// Functional category of this question.
    // Optional on input: when the field is absent, serde calls
    // `default_probe_category()`, which yields `ProbeCategory::Recall`.
    #[serde(default = "default_probe_category")]
    pub category: ProbeCategory,
}
58
59fn default_probe_category() -> ProbeCategory {
60    ProbeCategory::Recall
61}
62
63impl Default for ProbeQuestion {
64    fn default() -> Self {
65        Self {
66            question: String::new(),
67            expected_answer: String::new(),
68            category: ProbeCategory::Recall,
69        }
70    }
71}
72
73/// Three-tier verdict for compaction probe quality.
74#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
75pub enum ProbeVerdict {
76    /// Score >= `threshold`: summary preserves enough context. Proceed.
77    Pass,
78    /// Score in [`hard_fail_threshold`, `threshold`): summary is borderline.
79    /// Proceed with compaction but log a warning.
80    SoftFail,
81    /// Score < `hard_fail_threshold`: summary lost critical facts. Block compaction.
82    HardFail,
83    /// Transport/timeout failure — no quality score produced.
84    Error,
85}
86
87/// Full result of a compaction probe run.
88#[derive(Debug, Clone, Serialize, Deserialize)]
89pub struct CompactionProbeResult {
90    /// Overall score in [0.0, 1.0].
91    pub score: f32,
92    /// Per-category breakdown. Categories with 0 questions are omitted.
93    #[serde(default)]
94    pub category_scores: Vec<CategoryScore>,
95    /// Per-question breakdown.
96    pub questions: Vec<ProbeQuestion>,
97    /// LLM answers to the questions (positionally aligned with `questions`).
98    pub answers: Vec<String>,
99    /// Per-question similarity scores.
100    pub per_question_scores: Vec<f32>,
101    pub verdict: ProbeVerdict,
102    /// Pass threshold used for this run.
103    pub threshold: f32,
104    /// Hard-fail threshold used for this run.
105    pub hard_fail_threshold: f32,
106    /// Model name used for the probe.
107    pub model: String,
108    /// Wall-clock duration in milliseconds.
109    pub duration_ms: u64,
110}
111
112// --- Structured LLM output types ---
113
// Response shape requested from the question-generation call in
// `generate_probe_questions` (deserialized via `chat_typed_erased`).
// Plain `//` comments are used deliberately: the `JsonSchema` derive may turn `///`
// docs into schema descriptions that could be forwarded to the model.
#[derive(Debug, Deserialize, JsonSchema)]
struct ProbeQuestionsOutput {
    // May contain more entries than requested; the caller caps it to `max_questions`.
    questions: Vec<ProbeQuestion>,
}
118
119// --- Category scoring ---
120
121/// Group per-question scores by category and compute per-category averages.
122///
123/// Categories with no questions are excluded from the returned list.
124#[must_use]
125fn compute_category_scores(
126    questions: &[ProbeQuestion],
127    per_question_scores: &[f32],
128    category_weights: Option<&HashMap<ProbeCategory, f32>>,
129) -> (Vec<CategoryScore>, f32) {
130    // Group scores by category.
131    let mut by_cat: HashMap<ProbeCategory, Vec<f32>> = HashMap::new();
132    for (q, &s) in questions.iter().zip(per_question_scores.iter()) {
133        by_cat.entry(q.category).or_default().push(s);
134    }
135
136    #[allow(clippy::cast_precision_loss)]
137    let category_scores: Vec<CategoryScore> = by_cat
138        .into_iter()
139        .map(|(category, scores)| {
140            let avg = scores.iter().sum::<f32>() / scores.len() as f32;
141            CategoryScore {
142                category,
143                score: avg,
144                #[allow(clippy::cast_possible_truncation)]
145                probes_run: scores.len() as u32,
146            }
147        })
148        .collect();
149
150    if category_scores.is_empty() {
151        return (category_scores, 0.0);
152    }
153
154    // Compute weighted overall score. Default: equal weights.
155    let mut weighted_sum = 0.0_f32;
156    let mut weight_total = 0.0_f32;
157    for cs in &category_scores {
158        let raw_w = category_weights
159            .and_then(|m| m.get(&cs.category).copied())
160            .unwrap_or(1.0);
161        if raw_w < 0.0 {
162            tracing::warn!(
163                category = ?cs.category,
164                weight = raw_w,
165                "category_weights contains a negative value — treating as 0.0 (category excluded from scoring)"
166            );
167        }
168        let w = raw_w.max(0.0);
169        weighted_sum += cs.score * w;
170        weight_total += w;
171    }
172
173    let overall = if weight_total > 0.0 {
174        weighted_sum / weight_total
175    } else {
176        // All weights are 0 — fall back to equal weighting.
177        #[allow(clippy::cast_precision_loss)]
178        let n = category_scores.len() as f32;
179        category_scores.iter().map(|cs| cs.score).sum::<f32>() / n
180    };
181
182    (category_scores, overall)
183}
184
// Response shape requested from the answering call in `answer_probe_questions`
// (deserialized via `chat_typed_erased`). Plain `//` comments are used so the
// `JsonSchema` derive does not pick up new schema descriptions.
#[derive(Debug, Deserialize, JsonSchema)]
struct ProbeAnswersOutput {
    // One answer per question, expected in question order. The count is not
    // enforced by this type; `score_answers` scores missing answers as empty strings.
    answers: Vec<String>,
}
189
190// --- Scoring ---
191
/// Lowercase phrases that mark an answer as a refusal / "don't know".
///
/// Matched case-insensitively as substrings by `is_refusal`; an answer containing
/// any of them is scored `0.0`.
const REFUSAL_PATTERNS: &[&str] = &[
    // Explicit "don't know" replies (the answering prompt asks for "UNKNOWN").
    "unknown",
    "cannot determine",
    "n/a",
    // "The summary does not contain it" phrasings.
    "not mentioned",
    "not found",
    "no information",
    "not provided",
    "not specified",
    "not stated",
    "not available",
];
205
206fn is_refusal(text: &str) -> bool {
207    let lower = text.to_lowercase();
208    REFUSAL_PATTERNS.iter().any(|p| lower.contains(p))
209}
210
/// Normalize text into comparison tokens.
///
/// Lowercases, splits on every non-alphanumeric character, and keeps only tokens of
/// at least 3 characters. Length is counted in `char`s, not bytes, so short
/// non-ASCII words (e.g. the 2-letter Cyrillic "на", 4 bytes in UTF-8) are filtered
/// out just like their ASCII counterparts.
fn normalize_tokens(text: &str) -> Vec<String> {
    text.to_lowercase()
        .split(|c: char| !c.is_alphanumeric())
        .filter(|t| t.chars().count() >= 3)
        .map(String::from)
        .collect()
}
219
/// Jaccard similarity of two token lists treated as sets: `|A ∩ B| / |A ∪ B|`.
///
/// Duplicate tokens are ignored. Two empty lists count as identical (`1.0`). If
/// exactly one list is empty, the result is `0.0`.
#[allow(clippy::cast_precision_loss)]
fn jaccard(a: &[String], b: &[String]) -> f32 {
    use std::collections::HashSet;

    let left: HashSet<&str> = a.iter().map(String::as_str).collect();
    let right: HashSet<&str> = b.iter().map(String::as_str).collect();
    if left.is_empty() && right.is_empty() {
        return 1.0;
    }

    let shared = left.intersection(&right).count();
    // Inclusion–exclusion yields the union size without building the union set.
    let combined = left.len() + right.len() - shared;
    if combined == 0 {
        0.0
    } else {
        shared as f32 / combined as f32
    }
}
236
237/// Score a single (expected, actual) answer pair using token-set-ratio.
238fn score_pair(expected: &str, actual: &str) -> f32 {
239    if is_refusal(actual) {
240        return 0.0;
241    }
242
243    let tokens_e = normalize_tokens(expected);
244    let tokens_a = normalize_tokens(actual);
245
246    // Substring boost: if all expected tokens appear in actual, it's an exact match.
247    if !tokens_e.is_empty() {
248        let set_e: std::collections::HashSet<&str> = tokens_e.iter().map(String::as_str).collect();
249        let set_a: std::collections::HashSet<&str> = tokens_a.iter().map(String::as_str).collect();
250        if set_e.is_subset(&set_a) {
251            return 1.0;
252        }
253    }
254
255    // Token-set-ratio: max of three Jaccard variants.
256    let j_full = jaccard(&tokens_e, &tokens_a);
257
258    // Intersection with each set individually (handles subset relationships).
259    let set_e: std::collections::HashSet<&str> = tokens_e.iter().map(String::as_str).collect();
260    let set_a: std::collections::HashSet<&str> = tokens_a.iter().map(String::as_str).collect();
261    let intersection: Vec<String> = set_e
262        .intersection(&set_a)
263        .map(|s| (*s).to_owned())
264        .collect();
265
266    #[allow(clippy::cast_precision_loss)]
267    let j_e = if tokens_e.is_empty() {
268        0.0_f32
269    } else {
270        intersection.len() as f32 / tokens_e.len() as f32
271    };
272    #[allow(clippy::cast_precision_loss)]
273    let j_a = if tokens_a.is_empty() {
274        0.0_f32
275    } else {
276        intersection.len() as f32 / tokens_a.len() as f32
277    };
278
279    j_full.max(j_e).max(j_a)
280}
281
282/// Score answers against expected values using token-set-ratio similarity.
283///
284/// Returns `(per_question_scores, overall_average)`.
285#[must_use]
286pub fn score_answers(questions: &[ProbeQuestion], answers: &[String]) -> (Vec<f32>, f32) {
287    if questions.is_empty() {
288        return (vec![], 0.0);
289    }
290    let scores: Vec<f32> = questions
291        .iter()
292        .zip(answers.iter().chain(std::iter::repeat(&String::new())))
293        .map(|(q, a)| score_pair(&q.expected_answer, a))
294        .collect();
295    #[allow(clippy::cast_precision_loss)]
296    let avg = if scores.is_empty() {
297        0.0
298    } else {
299        scores.iter().sum::<f32>() / scores.len() as f32
300    };
301    (scores, avg)
302}
303
304// --- LLM calls ---
305
306/// Truncate tool-result bodies to 500 chars to avoid flooding the probe with raw output.
307fn truncate_tool_bodies(messages: &[Message]) -> Vec<Message> {
308    messages
309        .iter()
310        .map(|m| {
311            let mut msg = m.clone();
312            for part in &mut msg.parts {
313                if let MessagePart::ToolOutput { body, .. } = part {
314                    if body.len() <= 500 {
315                        continue;
316                    }
317                    body.truncate(500);
318                    body.push('\u{2026}');
319                }
320            }
321            msg.rebuild_content();
322            msg
323        })
324        .collect()
325}
326
327/// Generate factual probe questions from the messages being compacted.
328///
329/// Uses a single LLM call with structured output. Tool-result bodies are
330/// truncated to 500 chars to focus on decisions and outcomes rather than raw tool output.
331///
332/// # Errors
333///
334/// Returns `MemoryError::Llm` if the LLM call fails.
335pub async fn generate_probe_questions(
336    provider: &AnyProvider,
337    messages: &[Message],
338    max_questions: usize,
339) -> Result<Vec<ProbeQuestion>, MemoryError> {
340    let truncated = truncate_tool_bodies(messages);
341
342    let mut history = String::new();
343    for msg in &truncated {
344        let role = match msg.role {
345            Role::User => "user",
346            Role::Assistant => "assistant",
347            Role::System => "system",
348        };
349        history.push_str(role);
350        history.push_str(": ");
351        history.push_str(&msg.content);
352        history.push('\n');
353    }
354
355    let prompt = format!(
356        "Given the following conversation excerpt, generate {max_questions} factual questions \
357         that test whether a summary preserves the most important concrete details.\n\
358         \n\
359         You MUST generate at least one question per category when max_questions >= 4. \
360         If the conversation lacks information for a category, generate a question noting that absence.\n\
361         \n\
362         Categories:\n\
363         - recall: Specific facts that survived (file paths, function names, values). \
364           Example: \"What file was modified?\"\n\
365         - artifact: Which files/tools/URLs the agent used. \
366           Example: \"Which tool was executed?\"\n\
367         - continuation: Next steps, blockers, open questions. \
368           Example: \"What is the next step?\"\n\
369         - decision: Past reasoning traces (why X over Y, trade-offs). \
370           Example: \"Why was X chosen over Y?\"\n\
371         \n\
372         Do NOT generate questions about:\n\
373         - Raw tool output content (compiler warnings, test output line numbers)\n\
374         - Intermediate debugging steps that were superseded\n\
375         - Opinions or reasoning that cannot be verified\n\
376         \n\
377         Each question must have a single unambiguous expected answer extractable from the text.\n\
378         \n\
379         Conversation:\n{history}\n\
380         \n\
381         Respond in JSON with schema: {{\"questions\": [{{\"question\": \"...\", \
382         \"expected_answer\": \"...\", \"category\": \"recall|artifact|continuation|decision\"}}]}}"
383    );
384
385    let msgs = [Message {
386        role: Role::User,
387        content: prompt,
388        parts: vec![],
389        metadata: MessageMetadata::default(),
390    }];
391
392    let mut output: ProbeQuestionsOutput = provider
393        .chat_typed_erased::<ProbeQuestionsOutput>(&msgs)
394        .await
395        .map_err(MemoryError::Llm)?;
396
397    // Cap the list to max_questions: a misbehaving LLM could return more.
398    output.questions.truncate(max_questions);
399
400    Ok(output.questions)
401}
402
403/// Answer probe questions using only the compaction summary as context.
404///
405/// # Errors
406///
407/// Returns `MemoryError::Llm` if the LLM call fails.
408pub async fn answer_probe_questions(
409    provider: &AnyProvider,
410    summary: &str,
411    questions: &[ProbeQuestion],
412) -> Result<Vec<String>, MemoryError> {
413    let mut numbered = String::new();
414    for (i, q) in questions.iter().enumerate() {
415        use std::fmt::Write as _;
416        let _ = writeln!(numbered, "{}. {}", i + 1, q.question);
417    }
418
419    let prompt = format!(
420        "Given the following summary of a conversation, answer each question using ONLY \
421         information present in the summary. If the answer is not in the summary, respond \
422         with \"UNKNOWN\".\n\
423         \n\
424         Summary:\n{summary}\n\
425         \n\
426         Questions:\n{numbered}\n\
427         \n\
428         Respond in JSON with schema: {{\"answers\": [\"answer1\", \"answer2\", ...]}}"
429    );
430
431    let msgs = [Message {
432        role: Role::User,
433        content: prompt,
434        parts: vec![],
435        metadata: MessageMetadata::default(),
436    }];
437
438    let output: ProbeAnswersOutput = provider
439        .chat_typed_erased::<ProbeAnswersOutput>(&msgs)
440        .await
441        .map_err(MemoryError::Llm)?;
442
443    Ok(output.answers)
444}
445
446/// Configuration for the compaction probe.
447#[derive(Debug, Clone, Serialize, Deserialize)]
448#[serde(default)]
449pub struct CompactionProbeConfig {
450    /// Enable compaction probe validation. Default: `false`.
451    pub enabled: bool,
452    /// Provider name from `[[llm.providers]]` for probe LLM calls.
453    /// Empty string = use the summary provider.
454    pub probe_provider: String,
455    /// Minimum score to pass without warnings. Default: `0.6`.
456    /// Scores in [`hard_fail_threshold`, `threshold`) trigger `SoftFail` (warn + proceed).
457    pub threshold: f32,
458    /// Score below this triggers `HardFail` (block compaction). Default: `0.35`.
459    pub hard_fail_threshold: f32,
460    /// Maximum number of probe questions to generate. Default: `5`.
461    pub max_questions: usize,
462    /// Timeout for the entire probe (both LLM calls) in seconds. Default: `15`.
463    pub timeout_secs: u64,
464    /// Optional per-category weight multipliers for the overall score.
465    /// When `None` or empty, all categories are weighted equally.
466    /// Example: `{ recall = 1.5, artifact = 1.0, continuation = 1.0, decision = 0.8 }`
467    #[serde(default)]
468    pub category_weights: Option<HashMap<ProbeCategory, f32>>,
469}
470
471impl Default for CompactionProbeConfig {
472    fn default() -> Self {
473        Self {
474            enabled: false,
475            probe_provider: String::new(),
476            threshold: 0.6,
477            hard_fail_threshold: 0.35,
478            max_questions: 5,
479            timeout_secs: 15,
480            category_weights: None,
481        }
482    }
483}
484
485/// Run the compaction probe: generate questions, answer them from the summary, score results.
486///
487/// Returns `Ok(None)` when:
488/// - Probe is disabled (`config.enabled = false`)
489/// - The probe times out
490/// - Fewer than 2 questions are generated (insufficient statistical power)
491///
492/// The caller treats `None` as "no opinion" and proceeds with compaction.
493///
494/// # Errors
495///
496/// Returns `MemoryError` if an LLM call fails. Callers should treat this as non-fatal
497/// and proceed with compaction.
498pub async fn validate_compaction(
499    provider: &AnyProvider,
500    messages: &[Message],
501    summary: &str,
502    config: &CompactionProbeConfig,
503) -> Result<Option<CompactionProbeResult>, MemoryError> {
504    if !config.enabled {
505        return Ok(None);
506    }
507
508    let timeout = std::time::Duration::from_secs(config.timeout_secs);
509    let start = Instant::now();
510
511    let result = tokio::time::timeout(timeout, async {
512        run_probe(provider, messages, summary, config).await
513    })
514    .await;
515
516    match result {
517        Ok(inner) => inner,
518        Err(_elapsed) => {
519            tracing::warn!(
520                timeout_secs = config.timeout_secs,
521                "compaction probe timed out — proceeding with compaction"
522            );
523            Ok(None)
524        }
525    }
526    .map(|opt| {
527        opt.map(|mut r| {
528            r.duration_ms = u64::try_from(start.elapsed().as_millis()).unwrap_or(u64::MAX);
529            r
530        })
531    })
532}
533
534async fn run_probe(
535    provider: &AnyProvider,
536    messages: &[Message],
537    summary: &str,
538    config: &CompactionProbeConfig,
539) -> Result<Option<CompactionProbeResult>, MemoryError> {
540    if summary.len() < 10 {
541        tracing::warn!(
542            len = summary.len(),
543            "compaction probe: summary too short — skipping probe"
544        );
545        return Ok(None);
546    }
547
548    let questions = generate_probe_questions(provider, messages, config.max_questions).await?;
549
550    if questions.len() < 2 {
551        tracing::debug!(
552            count = questions.len(),
553            "compaction probe: fewer than 2 questions generated — skipping probe"
554        );
555        return Ok(None);
556    }
557
558    // Warn if any category is missing when we expected full coverage.
559    if config.max_questions >= 4 {
560        use std::collections::HashSet;
561        let covered: HashSet<_> = questions.iter().map(|q| q.category).collect();
562        for cat in [
563            ProbeCategory::Recall,
564            ProbeCategory::Artifact,
565            ProbeCategory::Continuation,
566            ProbeCategory::Decision,
567        ] {
568            if !covered.contains(&cat) {
569                tracing::warn!(
570                    category = ?cat,
571                    "compaction probe: LLM did not generate questions for category"
572                );
573            }
574        }
575    }
576
577    let answers = answer_probe_questions(provider, summary, &questions).await?;
578
579    let (per_question_scores, _simple_avg) = score_answers(&questions, &answers);
580
581    let (category_scores, score) = compute_category_scores(
582        &questions,
583        &per_question_scores,
584        config.category_weights.as_ref(),
585    );
586
587    let verdict = if score >= config.threshold {
588        ProbeVerdict::Pass
589    } else if score >= config.hard_fail_threshold {
590        ProbeVerdict::SoftFail
591    } else {
592        ProbeVerdict::HardFail
593    };
594
595    let model = provider.name().to_owned();
596
597    Ok(Some(CompactionProbeResult {
598        score,
599        category_scores,
600        questions,
601        answers,
602        per_question_scores,
603        verdict,
604        threshold: config.threshold,
605        hard_fail_threshold: config.hard_fail_threshold,
606        model,
607        duration_ms: 0, // filled in by validate_compaction
608    }))
609}
610
611#[cfg(test)]
612mod tests {
613    use super::*;
614
615    // --- score_answers tests ---
616
617    #[test]
618    fn score_perfect_match() {
619        let q = vec![ProbeQuestion {
620            question: "What crate is used?".into(),
621            expected_answer: "thiserror".into(),
622            category: ProbeCategory::Recall,
623        }];
624        let a = vec!["thiserror".into()];
625        let (scores, avg) = score_answers(&q, &a);
626        assert_eq!(scores.len(), 1);
627        assert!((avg - 1.0).abs() < 0.01, "expected ~1.0, got {avg}");
628    }
629
630    #[test]
631    fn score_complete_mismatch() {
632        let q = vec![ProbeQuestion {
633            question: "What file was modified?".into(),
634            expected_answer: "src/auth.rs".into(),
635            ..Default::default()
636        }];
637        let a = vec!["definitely not in the summary".into()];
638        let (scores, avg) = score_answers(&q, &a);
639        assert_eq!(scores.len(), 1);
640        // Very low overlap expected.
641        assert!(avg < 0.5, "expected low score, got {avg}");
642    }
643
644    #[test]
645    fn score_refusal_is_zero() {
646        let q = vec![ProbeQuestion {
647            question: "What was the decision?".into(),
648            expected_answer: "Use thiserror for typed errors".into(),
649            ..Default::default()
650        }];
651        for refusal in &[
652            "UNKNOWN",
653            "not mentioned",
654            "N/A",
655            "cannot determine",
656            "No information",
657        ] {
658            let a = vec![(*refusal).to_owned()];
659            let (_, avg) = score_answers(&q, &a);
660            assert!(avg < 0.01, "expected 0 for refusal '{refusal}', got {avg}");
661        }
662    }
663
664    #[test]
665    fn score_paraphrased_answer_above_half() {
666        // "thiserror was chosen for error types" vs "Use thiserror for typed errors"
667        // Shared tokens: "thiserror", "error" (and maybe "for"/"types"/"typed" with >=3 chars)
668        let q = vec![ProbeQuestion {
669            question: "What error handling crate was chosen?".into(),
670            expected_answer: "Use thiserror for typed errors in library crates".into(),
671            ..Default::default()
672        }];
673        let a = vec!["thiserror was chosen for error types in library crates".into()];
674        let (_, avg) = score_answers(&q, &a);
675        assert!(avg > 0.5, "expected >0.5 for paraphrase, got {avg}");
676    }
677
678    #[test]
679    fn score_empty_strings() {
680        let q = vec![ProbeQuestion {
681            question: "What?".into(),
682            expected_answer: String::new(),
683            ..Default::default()
684        }];
685        let a = vec![String::new()];
686        let (scores, avg) = score_answers(&q, &a);
687        assert_eq!(scores.len(), 1);
688        // Both empty — jaccard of two empty sets returns 1.0 (exact match).
689        assert!(
690            (avg - 1.0).abs() < 0.01,
691            "expected 1.0 for empty vs empty, got {avg}"
692        );
693    }
694
695    #[test]
696    fn score_empty_questions_list() {
697        let (scores, avg) = score_answers(&[], &[]);
698        assert!(scores.is_empty());
699        assert!((avg - 0.0).abs() < 0.01);
700    }
701
702    #[test]
703    fn score_file_path_exact() {
704        let q = vec![ProbeQuestion {
705            question: "Which file was modified?".into(),
706            expected_answer: "crates/zeph-memory/src/compaction_probe.rs".into(),
707            ..Default::default()
708        }];
709        let a = vec!["The file crates/zeph-memory/src/compaction_probe.rs was modified.".into()];
710        let (_, avg) = score_answers(&q, &a);
711        // Substring boost should fire: all expected tokens present in actual.
712        assert!(
713            avg > 0.8,
714            "expected high score for file path match, got {avg}"
715        );
716    }
717
718    #[test]
719    fn score_unicode_input() {
720        let q = vec![ProbeQuestion {
721            question: "Что было изменено?".into(),
722            expected_answer: "файл config.toml".into(),
723            ..Default::default()
724        }];
725        let a = vec!["config.toml был изменён".into()];
726        // Just verify no panic; score may vary.
727        let (scores, _) = score_answers(&q, &a);
728        assert_eq!(scores.len(), 1);
729    }
730
731    // --- verdict threshold tests ---
732
733    #[test]
734    fn verdict_thresholds() {
735        let config = CompactionProbeConfig::default();
736
737        // Pass >= 0.6
738        let score = 0.7_f32;
739        let verdict = if score >= config.threshold {
740            ProbeVerdict::Pass
741        } else if score >= config.hard_fail_threshold {
742            ProbeVerdict::SoftFail
743        } else {
744            ProbeVerdict::HardFail
745        };
746        assert_eq!(verdict, ProbeVerdict::Pass);
747
748        // SoftFail [0.35, 0.6)
749        let score = 0.5_f32;
750        let verdict = if score >= config.threshold {
751            ProbeVerdict::Pass
752        } else if score >= config.hard_fail_threshold {
753            ProbeVerdict::SoftFail
754        } else {
755            ProbeVerdict::HardFail
756        };
757        assert_eq!(verdict, ProbeVerdict::SoftFail);
758
759        // HardFail < 0.35
760        let score = 0.2_f32;
761        let verdict = if score >= config.threshold {
762            ProbeVerdict::Pass
763        } else if score >= config.hard_fail_threshold {
764            ProbeVerdict::SoftFail
765        } else {
766            ProbeVerdict::HardFail
767        };
768        assert_eq!(verdict, ProbeVerdict::HardFail);
769    }
770
771    // --- config defaults ---
772
773    #[test]
774    fn config_defaults() {
775        let c = CompactionProbeConfig::default();
776        assert!(!c.enabled);
777        assert!(c.probe_provider.is_empty());
778        assert!((c.threshold - 0.6).abs() < 0.001);
779        assert!((c.hard_fail_threshold - 0.35).abs() < 0.001);
780        assert_eq!(c.max_questions, 5);
781        assert_eq!(c.timeout_secs, 15);
782        assert!(c.category_weights.is_none());
783    }
784
785    // --- serde round-trip ---
786
787    #[test]
788    fn config_serde_round_trip() {
789        let original = CompactionProbeConfig {
790            enabled: true,
791            probe_provider: "fast".into(),
792            threshold: 0.65,
793            hard_fail_threshold: 0.4,
794            max_questions: 5,
795            timeout_secs: 20,
796            category_weights: None,
797        };
798        let json = serde_json::to_string(&original).expect("serialize");
799        let restored: CompactionProbeConfig = serde_json::from_str(&json).expect("deserialize");
800        assert!(restored.enabled);
801        assert_eq!(restored.probe_provider, "fast");
802        assert!((restored.threshold - 0.65).abs() < 0.001);
803    }
804
805    #[test]
806    fn probe_result_serde_round_trip() {
807        let result = CompactionProbeResult {
808            score: 0.75,
809            category_scores: vec![CategoryScore {
810                category: ProbeCategory::Recall,
811                score: 0.75,
812                probes_run: 1,
813            }],
814            questions: vec![ProbeQuestion {
815                question: "What?".into(),
816                expected_answer: "thiserror".into(),
817                category: ProbeCategory::Recall,
818            }],
819            answers: vec!["thiserror".into()],
820            per_question_scores: vec![1.0],
821            verdict: ProbeVerdict::Pass,
822            threshold: 0.6,
823            hard_fail_threshold: 0.35,
824            model: "haiku".into(),
825            duration_ms: 1234,
826        };
827        let json = serde_json::to_string(&result).expect("serialize");
828        let restored: CompactionProbeResult = serde_json::from_str(&json).expect("deserialize");
829        assert!((restored.score - 0.75).abs() < 0.001);
830        assert_eq!(restored.verdict, ProbeVerdict::Pass);
831        assert_eq!(restored.category_scores.len(), 1);
832    }
833
834    #[test]
835    fn probe_result_backward_compat_no_category_scores() {
836        // Old JSON without category_scores field must deserialize with empty vec.
837        let json = r#"{"score":0.75,"questions":[],"answers":[],"per_question_scores":[],"verdict":"Pass","threshold":0.6,"hard_fail_threshold":0.35,"model":"haiku","duration_ms":0}"#;
838        let restored: CompactionProbeResult = serde_json::from_str(json).expect("deserialize");
839        assert!(restored.category_scores.is_empty());
840    }
841
842    // --- fewer answers than questions (LLM returned truncated list) ---
843
844    #[test]
845    fn score_fewer_answers_than_questions() {
846        let questions = vec![
847            ProbeQuestion {
848                question: "What crate?".into(),
849                expected_answer: "thiserror".into(),
850                ..Default::default()
851            },
852            ProbeQuestion {
853                question: "What file?".into(),
854                expected_answer: "src/lib.rs".into(),
855                ..Default::default()
856            },
857            ProbeQuestion {
858                question: "What decision?".into(),
859                expected_answer: "use async traits".into(),
860                ..Default::default()
861            },
862        ];
863        // LLM only returned 1 answer for 3 questions.
864        let answers = vec!["thiserror".into()];
865        let (scores, avg) = score_answers(&questions, &answers);
866        // scores must have the same length as questions (missing answers → empty string → ~0).
867        assert_eq!(scores.len(), 3);
868        // First answer is a perfect match.
869        assert!(
870            (scores[0] - 1.0).abs() < 0.01,
871            "first score should be ~1.0, got {}",
872            scores[0]
873        );
874        // Missing answers score 0 (empty string vs non-empty expected).
875        assert!(
876            scores[1] < 0.5,
877            "second score should be low for missing answer, got {}",
878            scores[1]
879        );
880        assert!(
881            scores[2] < 0.5,
882            "third score should be low for missing answer, got {}",
883            scores[2]
884        );
885        // Average is dragged down by the two missing answers.
886        assert!(
887            avg < 0.5,
888            "average should be below 0.5 with 2 missing answers, got {avg}"
889        );
890    }
891
892    // --- exact boundary values for threshold ---
893
894    #[test]
895    fn verdict_boundary_at_threshold() {
896        let config = CompactionProbeConfig::default();
897
898        // Exactly at pass threshold → Pass.
899        let score = config.threshold;
900        let verdict = if score >= config.threshold {
901            ProbeVerdict::Pass
902        } else if score >= config.hard_fail_threshold {
903            ProbeVerdict::SoftFail
904        } else {
905            ProbeVerdict::HardFail
906        };
907        assert_eq!(verdict, ProbeVerdict::Pass);
908
909        // One ULP below pass threshold, above hard-fail → SoftFail.
910        let score = config.threshold - f32::EPSILON;
911        let verdict = if score >= config.threshold {
912            ProbeVerdict::Pass
913        } else if score >= config.hard_fail_threshold {
914            ProbeVerdict::SoftFail
915        } else {
916            ProbeVerdict::HardFail
917        };
918        assert_eq!(verdict, ProbeVerdict::SoftFail);
919
920        // Exactly at hard-fail threshold → SoftFail (boundary is inclusive).
921        let score = config.hard_fail_threshold;
922        let verdict = if score >= config.threshold {
923            ProbeVerdict::Pass
924        } else if score >= config.hard_fail_threshold {
925            ProbeVerdict::SoftFail
926        } else {
927            ProbeVerdict::HardFail
928        };
929        assert_eq!(verdict, ProbeVerdict::SoftFail);
930
931        // One ULP below hard-fail threshold → HardFail.
932        let score = config.hard_fail_threshold - f32::EPSILON;
933        let verdict = if score >= config.threshold {
934            ProbeVerdict::Pass
935        } else if score >= config.hard_fail_threshold {
936            ProbeVerdict::SoftFail
937        } else {
938            ProbeVerdict::HardFail
939        };
940        assert_eq!(verdict, ProbeVerdict::HardFail);
941    }
942
943    // --- config partial deserialization (serde default fields) ---
944
945    #[test]
946    fn config_partial_json_uses_defaults() {
947        // Only `enabled` is specified; all other fields must fall back to defaults via #[serde(default)].
948        let json = r#"{"enabled": true}"#;
949        let c: CompactionProbeConfig =
950            serde_json::from_str(json).expect("deserialize partial json");
951        assert!(c.enabled);
952        assert!(c.probe_provider.is_empty());
953        assert!((c.threshold - 0.6).abs() < 0.001);
954        assert!((c.hard_fail_threshold - 0.35).abs() < 0.001);
955        assert_eq!(c.max_questions, 5);
956        assert_eq!(c.timeout_secs, 15);
957    }
958
959    #[test]
960    fn config_empty_json_uses_all_defaults() {
961        let c: CompactionProbeConfig = serde_json::from_str("{}").expect("deserialize empty json");
962        assert!(!c.enabled);
963        assert!(c.probe_provider.is_empty());
964    }
965
966    // --- ProbeCategory serde ---
967
968    #[test]
969    fn probe_category_serde_lowercase() {
970        assert_eq!(
971            serde_json::to_string(&ProbeCategory::Recall).unwrap(),
972            r#""recall""#
973        );
974        assert_eq!(
975            serde_json::to_string(&ProbeCategory::Artifact).unwrap(),
976            r#""artifact""#
977        );
978        assert_eq!(
979            serde_json::to_string(&ProbeCategory::Continuation).unwrap(),
980            r#""continuation""#
981        );
982        assert_eq!(
983            serde_json::to_string(&ProbeCategory::Decision).unwrap(),
984            r#""decision""#
985        );
986        let cat: ProbeCategory = serde_json::from_str(r#""recall""#).unwrap();
987        assert_eq!(cat, ProbeCategory::Recall);
988    }
989
990    // --- category_weights TOML compat ---
991
992    #[test]
993    fn category_weights_toml_round_trip() {
994        let toml_str = r#"
995enabled = true
996probe_provider = "fast"
997threshold = 0.6
998hard_fail_threshold = 0.35
999max_questions = 5
1000timeout_secs = 15
1001[category_weights]
1002recall = 1.5
1003artifact = 1.0
1004continuation = 1.0
1005decision = 0.8
1006"#;
1007        let c: CompactionProbeConfig = toml::from_str(toml_str).expect("deserialize toml");
1008        let weights = c.category_weights.as_ref().unwrap();
1009        assert!((weights[&ProbeCategory::Recall] - 1.5).abs() < 0.001);
1010        assert!((weights[&ProbeCategory::Decision] - 0.8).abs() < 0.001);
1011    }
1012
1013    // --- compute_category_scores ---
1014
1015    #[test]
1016    fn category_scores_equal_weights() {
1017        let questions = vec![
1018            ProbeQuestion {
1019                question: "Q1".into(),
1020                expected_answer: "A1".into(),
1021                category: ProbeCategory::Recall,
1022            },
1023            ProbeQuestion {
1024                question: "Q2".into(),
1025                expected_answer: "A2".into(),
1026                category: ProbeCategory::Artifact,
1027            },
1028        ];
1029        let scores = [1.0_f32, 0.0_f32];
1030        let (cats, overall) = compute_category_scores(&questions, &scores, None);
1031        assert_eq!(cats.len(), 2);
1032        // Equal weight: (1.0 + 0.0) / 2 = 0.5
1033        assert!((overall - 0.5).abs() < 0.001, "overall={overall}");
1034    }
1035
1036    #[test]
1037    fn category_scores_missing_category_excluded() {
1038        // Only Recall and Decision present; Artifact/Continuation absent.
1039        let questions = vec![
1040            ProbeQuestion {
1041                question: "Q1".into(),
1042                expected_answer: "A1".into(),
1043                category: ProbeCategory::Recall,
1044            },
1045            ProbeQuestion {
1046                question: "Q2".into(),
1047                expected_answer: "A2".into(),
1048                category: ProbeCategory::Decision,
1049            },
1050        ];
1051        let scores = [1.0_f32, 0.6_f32];
1052        let (cats, _overall) = compute_category_scores(&questions, &scores, None);
1053        assert_eq!(cats.len(), 2, "only categories with questions present");
1054        let categories: Vec<_> = cats.iter().map(|c| c.category).collect();
1055        assert!(!categories.contains(&ProbeCategory::Artifact));
1056        assert!(!categories.contains(&ProbeCategory::Continuation));
1057    }
1058
1059    #[test]
1060    fn category_scores_custom_weights() {
1061        let questions = vec![
1062            ProbeQuestion {
1063                question: "Q1".into(),
1064                expected_answer: "A1".into(),
1065                category: ProbeCategory::Recall,
1066            },
1067            ProbeQuestion {
1068                question: "Q2".into(),
1069                expected_answer: "A2".into(),
1070                category: ProbeCategory::Decision,
1071            },
1072        ];
1073        let scores = [1.0_f32, 0.0_f32];
1074        let mut weights = HashMap::new();
1075        weights.insert(ProbeCategory::Recall, 2.0_f32);
1076        weights.insert(ProbeCategory::Decision, 1.0_f32);
1077        let (_, overall) = compute_category_scores(&questions, &scores, Some(&weights));
1078        // (1.0*2 + 0.0*1) / (2+1) = 0.666..
1079        assert!(
1080            (overall - 2.0 / 3.0).abs() < 0.001,
1081            "expected ~0.667, got {overall}"
1082        );
1083    }
1084
1085    #[test]
1086    fn category_scores_all_zero_weights_fallback() {
1087        let questions = vec![
1088            ProbeQuestion {
1089                question: "Q1".into(),
1090                expected_answer: "A1".into(),
1091                category: ProbeCategory::Recall,
1092            },
1093            ProbeQuestion {
1094                question: "Q2".into(),
1095                expected_answer: "A2".into(),
1096                category: ProbeCategory::Artifact,
1097            },
1098        ];
1099        let scores = [1.0_f32, 0.0_f32];
1100        let mut weights = HashMap::new();
1101        weights.insert(ProbeCategory::Recall, 0.0_f32);
1102        weights.insert(ProbeCategory::Artifact, 0.0_f32);
1103        let (_, overall) = compute_category_scores(&questions, &scores, Some(&weights));
1104        // Fallback to equal weighting: (1.0 + 0.0) / 2 = 0.5
1105        assert!((overall - 0.5).abs() < 0.001, "overall={overall}");
1106    }
1107
1108    #[test]
1109    fn category_scores_empty_questions() {
1110        let (cats, overall) = compute_category_scores(&[], &[], None);
1111        assert!(cats.is_empty());
1112        assert!((overall - 0.0).abs() < 0.001);
1113    }
1114
1115    #[test]
1116    fn category_scores_multi_probe_single_category_averages() {
1117        // Three Recall questions with scores 1.0, 0.0, 0.5 → average 0.5.
1118        let questions = vec![
1119            ProbeQuestion {
1120                question: "Q1".into(),
1121                expected_answer: "A1".into(),
1122                category: ProbeCategory::Recall,
1123            },
1124            ProbeQuestion {
1125                question: "Q2".into(),
1126                expected_answer: "A2".into(),
1127                category: ProbeCategory::Recall,
1128            },
1129            ProbeQuestion {
1130                question: "Q3".into(),
1131                expected_answer: "A3".into(),
1132                category: ProbeCategory::Recall,
1133            },
1134        ];
1135        let scores = [1.0_f32, 0.0_f32, 0.5_f32];
1136        let (cats, overall) = compute_category_scores(&questions, &scores, None);
1137        assert_eq!(cats.len(), 1, "only one category present");
1138        assert_eq!(cats[0].category, ProbeCategory::Recall);
1139        assert_eq!(cats[0].probes_run, 3);
1140        assert!(
1141            (cats[0].score - 0.5).abs() < 0.001,
1142            "cat score={}",
1143            cats[0].score
1144        );
1145        // With one category, overall equals that category's average.
1146        assert!((overall - 0.5).abs() < 0.001, "overall={overall}");
1147    }
1148
1149    #[test]
1150    fn probe_question_serde_default_category() {
1151        // Old JSON without `category` field must deserialize to Recall via #[serde(default)].
1152        let json = r#"{"question":"What file?","expected_answer":"src/lib.rs"}"#;
1153        let q: ProbeQuestion = serde_json::from_str(json).expect("deserialize");
1154        assert_eq!(q.category, ProbeCategory::Recall);
1155        assert_eq!(q.question, "What file?");
1156        assert_eq!(q.expected_answer, "src/lib.rs");
1157    }
1158
1159    #[test]
1160    fn probe_question_serde_all_categories_round_trip() {
1161        for cat in [
1162            ProbeCategory::Recall,
1163            ProbeCategory::Artifact,
1164            ProbeCategory::Continuation,
1165            ProbeCategory::Decision,
1166        ] {
1167            let q = ProbeQuestion {
1168                question: "test?".into(),
1169                expected_answer: "answer".into(),
1170                category: cat,
1171            };
1172            let json = serde_json::to_string(&q).expect("serialize");
1173            let restored: ProbeQuestion = serde_json::from_str(&json).expect("deserialize");
1174            assert_eq!(restored.category, cat);
1175        }
1176    }
1177}