tt_plan_core/quality.rs
1//! Tier 3 LLM-judge quality scoring for Plan projections.
2//!
3//! The judge call is the most expensive operation in the Plan pipeline
4//! ($0.05–$0.50/scan per spec §18 budget). This module owns the stratified
5//! sampling, judge dispatch, and risk-band aggregation. Production wires it
6//! to a real LLM provider; tests drive it with a deterministic [`MockJudge`].
7//!
8//! # Hard invariants
9//!
10//! 1. **Opt-in only**: quality sampling reads full request bodies (per
11//! ADR-009 / spec §11). [`score_quality`] refuses to run when
12//! [`QualityConfig::body_logging_enabled`] is `false`.
13//! 2. **Stratified sampling**: by `(tag, size_bucket)` — bucket on
14//! `input_tokens`. Proportional allocation, capped at total budget.
15//! 3. **Deterministic**: same `(requests, config, seed)` → bit-identical
16//! risk band + sampled request IDs.
17//! 4. **Judge agnostic**: scoring uses any [`JudgeProvider`] impl.
18//! 5. **Risk thresholds** (per spec §7.4): `LOW` if `degraded ≤ 5%`,
19//! `MEDIUM` if `5% < degraded ≤ 15%`, `HIGH` if `degraded > 15%`.
20//! 6. **Cost-capped**: refuses to dispatch when projected cost exceeds
21//! [`QualityConfig::budget_usd`].
22//!
23//! # Scope
24//!
25//! Re-running the proposed model is *out of scope* for [`score_quality`] —
26//! the caller provides a `proposed_response_for(&Uuid) -> Option<String>`
27//! closure so production can dispatch to the real provider while tests can
28//! supply a canned map. Library + judge contract; dispatch is caller-owned.
29
30use std::collections::HashMap;
31
32use async_trait::async_trait;
33use rand::Rng;
34use rand::SeedableRng;
35use rand_chacha::ChaCha8Rng;
36use serde::{Deserialize, Serialize};
37use thiserror::Error;
38use uuid::Uuid;
39
40use crate::types::RequestLog;
41
42/// Aggregate risk classification per `docs/03-plan-replay-design.md` §7.4
43/// (task-spec thresholds: `≤5%` / `(5%, 15%]` / `>15%`).
44#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
45#[serde(rename_all = "lowercase")]
46pub enum RiskBand {
47 /// Degraded share of *classified* (non-Unclear) samples ≤ 5%.
48 Low,
49 /// Degraded share in `(5%, 15%]`.
50 Medium,
51 /// Degraded share > 15%.
52 High,
53}
54
55impl RiskBand {
56 /// Map a degraded percentage (0–100) to a band. Boundary policy:
57 /// `≤ 5%` → `Low`, `≤ 15%` → `Medium`, otherwise `High`.
58 #[must_use]
59 pub fn from_degraded_pct(p: f64) -> Self {
60 if p <= 5.0 {
61 Self::Low
62 } else if p <= 15.0 {
63 Self::Medium
64 } else {
65 Self::High
66 }
67 }
68}
69
70/// Per-sample judge verdict.
71#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
72#[serde(rename_all = "lowercase")]
73pub enum JudgeVerdict {
74 /// Proposed response is interchangeable with the original.
75 Acceptable,
76 /// Proposed response is materially worse than the original.
77 Degraded,
78 /// Judge declined to classify (model refusal, parse failure, etc.).
79 /// Counted toward neither acceptable nor degraded but recorded in the
80 /// total. A high `Unclear` share surfaces as a user-visible caveat.
81 Unclear,
82}
83
84/// One sampled request's score.
85#[derive(Debug, Clone, Serialize, Deserialize)]
86pub struct SampleScore {
87 /// Stable identifier of the source [`RequestLog`].
88 pub request_id: Uuid,
89 /// The judge's classification.
90 pub verdict: JudgeVerdict,
91 /// Best-effort one-line reason from the judge. Trimmed to 200 chars.
92 pub reason: String,
93}
94
95/// Aggregated quality result attached to a `PlanResult`.
96#[derive(Debug, Clone, Serialize, Deserialize)]
97pub struct QualityResult {
98 /// Number of requests the judge actually scored.
99 pub sample_size: u32,
100 /// Count of `Acceptable` verdicts.
101 pub acceptable_count: u32,
102 /// Count of `Degraded` verdicts.
103 pub degraded_count: u32,
104 /// Count of `Unclear` verdicts.
105 pub unclear_count: u32,
106 /// `degraded_count / (acceptable_count + degraded_count) × 100` (0–100).
107 /// Defined over *classified* samples only — `Unclear` is excluded from
108 /// the denominator because by definition we don't know its valence.
109 pub degraded_pct: f64,
110 /// Aggregate band — feeds the user-facing red/yellow/green pill.
111 pub risk_band: RiskBand,
112 /// Per-sample scores, in stable order. Bounded by
113 /// [`QualityConfig::total_samples`].
114 pub sampled_examples: Vec<SampleScore>,
115 /// Human-readable warnings (small sample, high `Unclear` share, etc.).
116 pub caveats: Vec<String>,
117}
118
119/// Configuration for one Tier 3 quality scoring run.
120#[derive(Debug, Clone, Serialize, Deserialize)]
121pub struct QualityConfig {
122 /// Required gate: caller must surface body-logging opt-in to the org.
123 /// `false` causes [`score_quality`] to error with
124 /// [`QualityError::BodyLoggingDisabled`].
125 pub body_logging_enabled: bool,
126 /// Total samples to draw across all strata (cap).
127 pub total_samples: u32,
128 /// Hard cost ceiling for judge calls in this run (USD).
129 pub budget_usd: f64,
130 /// Estimated USD per judge call (typically $0.001–$0.01 for a Sonnet
131 /// judge). Multiplied by `total_samples` for the up-front budget check.
132 pub cost_per_judge_call_usd: f64,
133 /// Random seed for stratified sampling. Determinism contract.
134 pub seed: u64,
135}
136
137/// Errors surfaced by [`score_quality`].
138#[derive(Debug, Error)]
139pub enum QualityError {
140 /// Caller invoked scoring without the body-logging opt-in. Tier 3
141 /// requires raw prompts + responses on each sampled row.
142 #[error("body logging not opted in by org — Tier 3 quality scoring requires raw bodies")]
143 BodyLoggingDisabled,
144
145 /// Pre-flight estimate `cost_per_judge_call_usd × total_samples`
146 /// exceeded the caller's budget. Scoring did not dispatch.
147 #[error("projected judge cost ${cost:.4} exceeds budget ${budget:.4}")]
148 OverBudget {
149 /// Projected cost in USD.
150 cost: f64,
151 /// Configured budget ceiling in USD.
152 budget: f64,
153 },
154
155 /// The [`JudgeProvider`] failed mid-run. Holds the provider's message.
156 #[error("judge: {0}")]
157 Judge(String),
158
159 /// Every sampled row was missing either the prompt body or the
160 /// historical response body — nothing was sent to the judge.
161 #[error("no sampled requests carry both prompt + response bodies")]
162 NoScorable,
163}
164
165/// Pluggable judge backend. Production: an LLM provider call. Tests:
166/// [`MockJudge`].
167#[async_trait]
168pub trait JudgeProvider: Send + Sync {
169 /// Compare an "original" response and a "proposed" response for the same
170 /// input. Return a verdict + one-line reason.
171 ///
172 /// # Errors
173 ///
174 /// Returns [`QualityError::Judge`] when the underlying provider fails
175 /// in a way the caller should surface to the user (e.g. auth failure,
176 /// rate-limit exhaustion). Implementations that recover internally via
177 /// retry should return [`JudgeVerdict::Unclear`] rather than erroring.
178 async fn judge(
179 &self,
180 input_body: &str,
181 original_response: &str,
182 proposed_response: &str,
183 ) -> Result<(JudgeVerdict, String), QualityError>;
184}
185
186/// Deterministic mock — used by replay tests and any caller that wants to
187/// stub the judge for offline analysis.
188pub struct MockJudge {
189 /// Force every call to return this verdict.
190 pub verdict: JudgeVerdict,
191 /// Reason string echoed verbatim on every call.
192 pub reason: String,
193}
194
195#[async_trait]
196impl JudgeProvider for MockJudge {
197 async fn judge(
198 &self,
199 _input: &str,
200 _orig: &str,
201 _prop: &str,
202 ) -> Result<(JudgeVerdict, String), QualityError> {
203 Ok((self.verdict, self.reason.clone()))
204 }
205}
206
207/// Compute the `(tag, size_bucket)` stratum for a request. Bucket boundaries
208/// match `docs/03-plan-replay-design.md` §7.1 (`small ≤ 500`, `medium ≤ 4000`,
209/// `large > 4000` input tokens).
210fn stratify(req: &RequestLog) -> (Option<String>, &'static str) {
211 let bucket = match req.input_tokens {
212 0..=500 => "small",
213 501..=4000 => "medium",
214 _ => "large",
215 };
216 (req.tag.clone(), bucket)
217}
218
219/// Draw `n` requests stratified by `(tag, size_bucket)` proportional to the
220/// stratum's share of the input population. Deterministic given `seed` and
221/// the input slice order (we re-sort by `id` internally so callers don't
222/// have to).
223///
224/// Returns the sampled IDs in ascending order. When `n == 0` or `requests`
225/// is empty, returns an empty `Vec`. Rounding can produce ≤ `n` samples
226/// (never more — final truncation enforces the cap).
227#[must_use]
228pub fn stratified_sample(requests: &[RequestLog], n: u32, seed: u64) -> Vec<Uuid> {
229 if n == 0 || requests.is_empty() {
230 return Vec::new();
231 }
232
233 // Index requests into deterministic strata. Sort the input by id first
234 // so the per-stratum vectors are populated in a deterministic order
235 // even when callers pass arbitrary orderings.
236 let mut sorted: Vec<&RequestLog> = requests.iter().collect();
237 sorted.sort_by_key(|r| r.id);
238
239 let mut by_stratum: HashMap<(Option<String>, &'static str), Vec<Uuid>> = HashMap::new();
240 for r in &sorted {
241 by_stratum.entry(stratify(r)).or_default().push(r.id);
242 }
243
244 // Proportional allocation: each stratum gets `n × (stratum_size / total)`
245 // (rounded). Iterate strata in sorted key order so the RNG draws happen
246 // in a deterministic sequence.
247 let total = requests.len() as f64;
248 let n_f = f64::from(n);
249 let mut keys: Vec<_> = by_stratum.keys().cloned().collect();
250 keys.sort();
251
252 let mut rng = ChaCha8Rng::seed_from_u64(seed);
253 let mut out = Vec::new();
254 for k in keys {
255 let stratum = &by_stratum[&k];
256 let alloc = ((stratum.len() as f64 / total) * n_f).round() as usize;
257 let alloc = alloc.min(stratum.len());
258 if alloc == 0 {
259 continue;
260 }
261 // Fisher–Yates partial shuffle so the first `alloc` indices are a
262 // uniform without-replacement draw.
263 let mut idx: Vec<usize> = (0..stratum.len()).collect();
264 for i in (1..idx.len()).rev() {
265 let j = rng.gen_range(0..=i);
266 idx.swap(i, j);
267 }
268 for i in idx.into_iter().take(alloc) {
269 out.push(stratum[i]);
270 }
271 }
272
273 // Final dedupe + cap. Sort for deterministic order; the input id-sort
274 // means any two IDs that landed in the same stratum can't collide here,
275 // but sorting also makes the output independent of stratum iteration
276 // order so the snapshot stays stable across HashMap reorderings.
277 out.sort();
278 out.dedup();
279 if out.len() > n as usize {
280 out.truncate(n as usize);
281 }
282 out
283}
284
285/// Score quality by sampling requests, comparing original vs proposed
286/// responses via the judge, and aggregating into a [`RiskBand`].
287///
288/// `proposed_response_for(id)` lets the caller plug in the
289/// proposed-model dispatch — production routes through the real provider,
290/// tests supply canned strings. Returning `None` skips that sample.
291///
292/// # Errors
293///
294/// - [`QualityError::BodyLoggingDisabled`] when
295/// `config.body_logging_enabled == false`.
296/// - [`QualityError::OverBudget`] when projected judge cost exceeds the
297/// configured budget. Computed before any judge call dispatches.
298/// - [`QualityError::Judge`] when a judge call fails.
299/// - [`QualityError::NoScorable`] when every sampled row was missing
300/// prompt body, response body, or a proposed response.
301pub async fn score_quality<F>(
302 requests: &[RequestLog],
303 config: &QualityConfig,
304 judge: &dyn JudgeProvider,
305 proposed_response_for: F,
306) -> Result<QualityResult, QualityError>
307where
308 F: Fn(&Uuid) -> Option<String>,
309{
310 if !config.body_logging_enabled {
311 return Err(QualityError::BodyLoggingDisabled);
312 }
313 let projected_cost = config.cost_per_judge_call_usd * f64::from(config.total_samples);
314 if projected_cost > config.budget_usd {
315 return Err(QualityError::OverBudget {
316 cost: projected_cost,
317 budget: config.budget_usd,
318 });
319 }
320
321 let sampled_ids = stratified_sample(requests, config.total_samples, config.seed);
322 let by_id: HashMap<Uuid, &RequestLog> = requests.iter().map(|r| (r.id, r)).collect();
323
324 let mut scores = Vec::new();
325 let mut acceptable: u32 = 0;
326 let mut degraded: u32 = 0;
327 let mut unclear: u32 = 0;
328
329 for id in &sampled_ids {
330 let Some(req) = by_id.get(id) else { continue };
331 let Some(input) = req.body.as_ref() else {
332 continue;
333 };
334 let Some(original) = req.response_body.as_ref() else {
335 continue;
336 };
337 let Some(proposed) = proposed_response_for(id) else {
338 continue;
339 };
340
341 let (verdict, mut reason) = judge.judge(input, original, &proposed).await?;
342 if reason.len() > 200 {
343 // Truncate at a char boundary to avoid panic on multi-byte text.
344 let mut cut = 200;
345 while cut > 0 && !reason.is_char_boundary(cut) {
346 cut -= 1;
347 }
348 reason.truncate(cut);
349 }
350 match verdict {
351 JudgeVerdict::Acceptable => acceptable += 1,
352 JudgeVerdict::Degraded => degraded += 1,
353 JudgeVerdict::Unclear => unclear += 1,
354 }
355 scores.push(SampleScore {
356 request_id: *id,
357 verdict,
358 reason,
359 });
360 }
361
362 if scores.is_empty() {
363 return Err(QualityError::NoScorable);
364 }
365
366 let total_classified = f64::from(acceptable + degraded);
367 let degraded_pct = if total_classified > 0.0 {
368 (f64::from(degraded) / total_classified) * 100.0
369 } else {
370 0.0
371 };
372 let risk_band = RiskBand::from_degraded_pct(degraded_pct);
373
374 let mut caveats = Vec::new();
375 let unclear_share = f64::from(unclear) / scores.len() as f64;
376 if unclear_share > 0.20 {
377 caveats.push(format!(
378 "{:.0}% of sampled requests were Unclear — the judge couldn't classify. \
379 Risk band may be unreliable; consider a stronger judge model.",
380 unclear_share * 100.0
381 ));
382 }
383 if scores.len() < 30 {
384 caveats.push(format!(
385 "Small quality sample ({} scored) — risk band has wide uncertainty.",
386 scores.len()
387 ));
388 }
389
390 Ok(QualityResult {
391 sample_size: scores.len() as u32,
392 acceptable_count: acceptable,
393 degraded_count: degraded,
394 unclear_count: unclear,
395 degraded_pct,
396 risk_band,
397 sampled_examples: scores,
398 caveats,
399 })
400}
401
#[cfg(test)]
mod tests {
    use super::*;

    /// Pins both band boundaries: 5% and 15% are inclusive upper bounds.
    #[test]
    fn risk_band_thresholds() {
        let cases = [
            (0.0, RiskBand::Low),
            (5.0, RiskBand::Low),
            (5.0001, RiskBand::Medium),
            (15.0, RiskBand::Medium),
            (15.0001, RiskBand::High),
            (100.0, RiskBand::High),
        ];
        for (pct, expected) in cases {
            assert_eq!(RiskBand::from_degraded_pct(pct), expected);
        }
    }
}
415}