// tt_plan_core/quality.rs

//! Tier 3 LLM-judge quality scoring for Plan projections.
//!
//! The judge call is the most expensive operation in the Plan pipeline
//! ($0.05–$0.50/scan per spec §18 budget). This module owns the stratified
//! sampling, judge dispatch, and risk-band aggregation. Production wires it
//! to a real LLM provider; tests drive it with a deterministic [`MockJudge`].
//!
//! # Hard invariants
//!
//! 1. **Opt-in only**: quality sampling reads full request bodies (per
//!    ADR-009 / spec §11). [`score_quality`] refuses to run when
//!    [`QualityConfig::body_logging_enabled`] is `false`.
//! 2. **Stratified sampling**: by `(tag, size_bucket)`, where the size bucket
//!    is derived from `input_tokens`. Allocation is proportional to each
//!    stratum's share of the requests and capped at the total sample count,
//!    [`QualityConfig::total_samples`].
//! 3. **Deterministic**: the same `(requests, config, seed)` yields a
//!    bit-identical risk band and the same sampled request IDs.
//! 4. **Judge-agnostic**: scoring works with any [`JudgeProvider`] impl.
//! 5. **Risk thresholds** (per spec §7.4), computed over the degraded share
//!    of classified (non-`Unclear`) samples: `LOW` if `degraded ≤ 5%`,
//!    `MEDIUM` if `5% < degraded ≤ 15%`, `HIGH` if `degraded > 15%`.
//! 6. **Cost-capped**: refuses to dispatch when the projected cost exceeds
//!    [`QualityConfig::budget_usd`].
//!
//! # Scope
//!
//! Re-running the proposed model is *out of scope* for [`score_quality`]:
//! the caller provides a `proposed_response_for(&Uuid) -> Option<String>`
//! closure, so production can dispatch to the real provider while tests
//! supply a canned map. This module owns the library and judge contract;
//! dispatch is caller-owned.
29
30use std::collections::HashMap;
31
32use async_trait::async_trait;
33use rand::Rng;
34use rand::SeedableRng;
35use rand_chacha::ChaCha8Rng;
36use serde::{Deserialize, Serialize};
37use thiserror::Error;
38use uuid::Uuid;
39
40use crate::types::RequestLog;
41
42/// Aggregate risk classification per `docs/03-plan-replay-design.md` §7.4
43/// (task-spec thresholds: `≤5%` / `(5%, 15%]` / `>15%`).
44#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
45#[serde(rename_all = "lowercase")]
46pub enum RiskBand {
47    /// Degraded share of *classified* (non-Unclear) samples ≤ 5%.
48    Low,
49    /// Degraded share in `(5%, 15%]`.
50    Medium,
51    /// Degraded share > 15%.
52    High,
53}
54
55impl RiskBand {
56    /// Map a degraded percentage (0–100) to a band. Boundary policy:
57    /// `≤ 5%` → `Low`, `≤ 15%` → `Medium`, otherwise `High`.
58    #[must_use]
59    pub fn from_degraded_pct(p: f64) -> Self {
60        if p <= 5.0 {
61            Self::Low
62        } else if p <= 15.0 {
63            Self::Medium
64        } else {
65            Self::High
66        }
67    }
68}
69
70/// Per-sample judge verdict.
71#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
72#[serde(rename_all = "lowercase")]
73pub enum JudgeVerdict {
74    /// Proposed response is interchangeable with the original.
75    Acceptable,
76    /// Proposed response is materially worse than the original.
77    Degraded,
78    /// Judge declined to classify (model refusal, parse failure, etc.).
79    /// Counted toward neither acceptable nor degraded but recorded in the
80    /// total. A high `Unclear` share surfaces as a user-visible caveat.
81    Unclear,
82}
83
84/// One sampled request's score.
85#[derive(Debug, Clone, Serialize, Deserialize)]
86pub struct SampleScore {
87    /// Stable identifier of the source [`RequestLog`].
88    pub request_id: Uuid,
89    /// The judge's classification.
90    pub verdict: JudgeVerdict,
91    /// Best-effort one-line reason from the judge. Trimmed to 200 chars.
92    pub reason: String,
93}
94
95/// Aggregated quality result attached to a `PlanResult`.
96#[derive(Debug, Clone, Serialize, Deserialize)]
97pub struct QualityResult {
98    /// Number of requests the judge actually scored.
99    pub sample_size: u32,
100    /// Count of `Acceptable` verdicts.
101    pub acceptable_count: u32,
102    /// Count of `Degraded` verdicts.
103    pub degraded_count: u32,
104    /// Count of `Unclear` verdicts.
105    pub unclear_count: u32,
106    /// `degraded_count / (acceptable_count + degraded_count) × 100` (0–100).
107    /// Defined over *classified* samples only — `Unclear` is excluded from
108    /// the denominator because by definition we don't know its valence.
109    pub degraded_pct: f64,
110    /// Aggregate band — feeds the user-facing red/yellow/green pill.
111    pub risk_band: RiskBand,
112    /// Per-sample scores, in stable order. Bounded by
113    /// [`QualityConfig::total_samples`].
114    pub sampled_examples: Vec<SampleScore>,
115    /// Human-readable warnings (small sample, high `Unclear` share, etc.).
116    pub caveats: Vec<String>,
117}
118
119/// Configuration for one Tier 3 quality scoring run.
120#[derive(Debug, Clone, Serialize, Deserialize)]
121pub struct QualityConfig {
122    /// Required gate: caller must surface body-logging opt-in to the org.
123    /// `false` causes [`score_quality`] to error with
124    /// [`QualityError::BodyLoggingDisabled`].
125    pub body_logging_enabled: bool,
126    /// Total samples to draw across all strata (cap).
127    pub total_samples: u32,
128    /// Hard cost ceiling for judge calls in this run (USD).
129    pub budget_usd: f64,
130    /// Estimated USD per judge call (typically $0.001–$0.01 for a Sonnet
131    /// judge). Multiplied by `total_samples` for the up-front budget check.
132    pub cost_per_judge_call_usd: f64,
133    /// Random seed for stratified sampling. Determinism contract.
134    pub seed: u64,
135}
136
137/// Errors surfaced by [`score_quality`].
138#[derive(Debug, Error)]
139pub enum QualityError {
140    /// Caller invoked scoring without the body-logging opt-in. Tier 3
141    /// requires raw prompts + responses on each sampled row.
142    #[error("body logging not opted in by org — Tier 3 quality scoring requires raw bodies")]
143    BodyLoggingDisabled,
144
145    /// Pre-flight estimate `cost_per_judge_call_usd × total_samples`
146    /// exceeded the caller's budget. Scoring did not dispatch.
147    #[error("projected judge cost ${cost:.4} exceeds budget ${budget:.4}")]
148    OverBudget {
149        /// Projected cost in USD.
150        cost: f64,
151        /// Configured budget ceiling in USD.
152        budget: f64,
153    },
154
155    /// The [`JudgeProvider`] failed mid-run. Holds the provider's message.
156    #[error("judge: {0}")]
157    Judge(String),
158
159    /// Every sampled row was missing either the prompt body or the
160    /// historical response body — nothing was sent to the judge.
161    #[error("no sampled requests carry both prompt + response bodies")]
162    NoScorable,
163}
164
165/// Pluggable judge backend. Production: an LLM provider call. Tests:
166/// [`MockJudge`].
167#[async_trait]
168pub trait JudgeProvider: Send + Sync {
169    /// Compare an "original" response and a "proposed" response for the same
170    /// input. Return a verdict + one-line reason.
171    ///
172    /// # Errors
173    ///
174    /// Returns [`QualityError::Judge`] when the underlying provider fails
175    /// in a way the caller should surface to the user (e.g. auth failure,
176    /// rate-limit exhaustion). Implementations that recover internally via
177    /// retry should return [`JudgeVerdict::Unclear`] rather than erroring.
178    async fn judge(
179        &self,
180        input_body: &str,
181        original_response: &str,
182        proposed_response: &str,
183    ) -> Result<(JudgeVerdict, String), QualityError>;
184}
185
186/// Deterministic mock — used by replay tests and any caller that wants to
187/// stub the judge for offline analysis.
188pub struct MockJudge {
189    /// Force every call to return this verdict.
190    pub verdict: JudgeVerdict,
191    /// Reason string echoed verbatim on every call.
192    pub reason: String,
193}
194
195#[async_trait]
196impl JudgeProvider for MockJudge {
197    async fn judge(
198        &self,
199        _input: &str,
200        _orig: &str,
201        _prop: &str,
202    ) -> Result<(JudgeVerdict, String), QualityError> {
203        Ok((self.verdict, self.reason.clone()))
204    }
205}
206
207/// Compute the `(tag, size_bucket)` stratum for a request. Bucket boundaries
208/// match `docs/03-plan-replay-design.md` §7.1 (`small ≤ 500`, `medium ≤ 4000`,
209/// `large > 4000` input tokens).
210fn stratify(req: &RequestLog) -> (Option<String>, &'static str) {
211    let bucket = match req.input_tokens {
212        0..=500 => "small",
213        501..=4000 => "medium",
214        _ => "large",
215    };
216    (req.tag.clone(), bucket)
217}
218
219/// Draw `n` requests stratified by `(tag, size_bucket)` proportional to the
220/// stratum's share of the input population. Deterministic given `seed` and
221/// the input slice order (we re-sort by `id` internally so callers don't
222/// have to).
223///
224/// Returns the sampled IDs in ascending order. When `n == 0` or `requests`
225/// is empty, returns an empty `Vec`. Rounding can produce ≤ `n` samples
226/// (never more — final truncation enforces the cap).
227#[must_use]
228pub fn stratified_sample(requests: &[RequestLog], n: u32, seed: u64) -> Vec<Uuid> {
229    if n == 0 || requests.is_empty() {
230        return Vec::new();
231    }
232
233    // Index requests into deterministic strata. Sort the input by id first
234    // so the per-stratum vectors are populated in a deterministic order
235    // even when callers pass arbitrary orderings.
236    let mut sorted: Vec<&RequestLog> = requests.iter().collect();
237    sorted.sort_by_key(|r| r.id);
238
239    let mut by_stratum: HashMap<(Option<String>, &'static str), Vec<Uuid>> = HashMap::new();
240    for r in &sorted {
241        by_stratum.entry(stratify(r)).or_default().push(r.id);
242    }
243
244    // Proportional allocation: each stratum gets `n × (stratum_size / total)`
245    // (rounded). Iterate strata in sorted key order so the RNG draws happen
246    // in a deterministic sequence.
247    let total = requests.len() as f64;
248    let n_f = f64::from(n);
249    let mut keys: Vec<_> = by_stratum.keys().cloned().collect();
250    keys.sort();
251
252    let mut rng = ChaCha8Rng::seed_from_u64(seed);
253    let mut out = Vec::new();
254    for k in keys {
255        let stratum = &by_stratum[&k];
256        let alloc = ((stratum.len() as f64 / total) * n_f).round() as usize;
257        let alloc = alloc.min(stratum.len());
258        if alloc == 0 {
259            continue;
260        }
261        // Fisher–Yates partial shuffle so the first `alloc` indices are a
262        // uniform without-replacement draw.
263        let mut idx: Vec<usize> = (0..stratum.len()).collect();
264        for i in (1..idx.len()).rev() {
265            let j = rng.gen_range(0..=i);
266            idx.swap(i, j);
267        }
268        for i in idx.into_iter().take(alloc) {
269            out.push(stratum[i]);
270        }
271    }
272
273    // Final dedupe + cap. Sort for deterministic order; the input id-sort
274    // means any two IDs that landed in the same stratum can't collide here,
275    // but sorting also makes the output independent of stratum iteration
276    // order so the snapshot stays stable across HashMap reorderings.
277    out.sort();
278    out.dedup();
279    if out.len() > n as usize {
280        out.truncate(n as usize);
281    }
282    out
283}
284
285/// Score quality by sampling requests, comparing original vs proposed
286/// responses via the judge, and aggregating into a [`RiskBand`].
287///
288/// `proposed_response_for(id)` lets the caller plug in the
289/// proposed-model dispatch — production routes through the real provider,
290/// tests supply canned strings. Returning `None` skips that sample.
291///
292/// # Errors
293///
294/// - [`QualityError::BodyLoggingDisabled`] when
295///   `config.body_logging_enabled == false`.
296/// - [`QualityError::OverBudget`] when projected judge cost exceeds the
297///   configured budget. Computed before any judge call dispatches.
298/// - [`QualityError::Judge`] when a judge call fails.
299/// - [`QualityError::NoScorable`] when every sampled row was missing
300///   prompt body, response body, or a proposed response.
301pub async fn score_quality<F>(
302    requests: &[RequestLog],
303    config: &QualityConfig,
304    judge: &dyn JudgeProvider,
305    proposed_response_for: F,
306) -> Result<QualityResult, QualityError>
307where
308    F: Fn(&Uuid) -> Option<String>,
309{
310    if !config.body_logging_enabled {
311        return Err(QualityError::BodyLoggingDisabled);
312    }
313    let projected_cost = config.cost_per_judge_call_usd * f64::from(config.total_samples);
314    if projected_cost > config.budget_usd {
315        return Err(QualityError::OverBudget {
316            cost: projected_cost,
317            budget: config.budget_usd,
318        });
319    }
320
321    let sampled_ids = stratified_sample(requests, config.total_samples, config.seed);
322    let by_id: HashMap<Uuid, &RequestLog> = requests.iter().map(|r| (r.id, r)).collect();
323
324    let mut scores = Vec::new();
325    let mut acceptable: u32 = 0;
326    let mut degraded: u32 = 0;
327    let mut unclear: u32 = 0;
328
329    for id in &sampled_ids {
330        let Some(req) = by_id.get(id) else { continue };
331        let Some(input) = req.body.as_ref() else {
332            continue;
333        };
334        let Some(original) = req.response_body.as_ref() else {
335            continue;
336        };
337        let Some(proposed) = proposed_response_for(id) else {
338            continue;
339        };
340
341        let (verdict, mut reason) = judge.judge(input, original, &proposed).await?;
342        if reason.len() > 200 {
343            // Truncate at a char boundary to avoid panic on multi-byte text.
344            let mut cut = 200;
345            while cut > 0 && !reason.is_char_boundary(cut) {
346                cut -= 1;
347            }
348            reason.truncate(cut);
349        }
350        match verdict {
351            JudgeVerdict::Acceptable => acceptable += 1,
352            JudgeVerdict::Degraded => degraded += 1,
353            JudgeVerdict::Unclear => unclear += 1,
354        }
355        scores.push(SampleScore {
356            request_id: *id,
357            verdict,
358            reason,
359        });
360    }
361
362    if scores.is_empty() {
363        return Err(QualityError::NoScorable);
364    }
365
366    let total_classified = f64::from(acceptable + degraded);
367    let degraded_pct = if total_classified > 0.0 {
368        (f64::from(degraded) / total_classified) * 100.0
369    } else {
370        0.0
371    };
372    let risk_band = RiskBand::from_degraded_pct(degraded_pct);
373
374    let mut caveats = Vec::new();
375    let unclear_share = f64::from(unclear) / scores.len() as f64;
376    if unclear_share > 0.20 {
377        caveats.push(format!(
378            "{:.0}% of sampled requests were Unclear — the judge couldn't classify. \
379             Risk band may be unreliable; consider a stronger judge model.",
380            unclear_share * 100.0
381        ));
382    }
383    if scores.len() < 30 {
384        caveats.push(format!(
385            "Small quality sample ({} scored) — risk band has wide uncertainty.",
386            scores.len()
387        ));
388    }
389
390    Ok(QualityResult {
391        sample_size: scores.len() as u32,
392        acceptable_count: acceptable,
393        degraded_count: degraded,
394        unclear_count: unclear,
395        degraded_pct,
396        risk_band,
397        sampled_examples: scores,
398        caveats,
399    })
400}
401
#[cfg(test)]
mod tests {
    use super::*;

    /// Boundary table for `RiskBand::from_degraded_pct`: each threshold maps
    /// to the lower band, and anything just past it to the next band up.
    #[test]
    fn risk_band_thresholds() {
        let cases = [
            (0.0, RiskBand::Low),
            (5.0, RiskBand::Low),
            (5.0001, RiskBand::Medium),
            (15.0, RiskBand::Medium),
            (15.0001, RiskBand::High),
            (100.0, RiskBand::High),
        ];
        for (pct, expected) in cases {
            assert_eq!(RiskBand::from_degraded_pct(pct), expected, "pct = {pct}");
        }
    }
}