Skip to main content

shadow_core/diff/
drill_down.rs

1//! Per-pair drill-down: surfaces which specific turn in the paired
2//! trace set drove each aggregate axis regression.
3//!
4//! The nine-axis report is informative in aggregate but hides *where*
5//! the regression happened. A reviewer looking at a PR with 50 trace
6//! pairs sees `trajectory: delta +0.42, severe` but has to hand-audit
7//! every pair to find which ones actually regressed. This module
8//! computes per-pair, per-axis deltas and returns the top-K
9//! most-regressive pairs ranked by a normalised aggregate score.
10//!
11//! Design choices:
12//!
13//! 1. **No bootstrap per pair.** Bootstrap CIs are an aggregate-level
14//!    stat; per-pair we need only the raw deltas. This keeps drill-down
15//!    cheap (O(N) extraction per axis, no resampling).
16//!
17//! 2. **Self-contained extractors.** Rather than refactor the nine axis
18//!    modules to expose their per-pair internals, we re-implement the
19//!    (small) extractors here. Each is ≤ 20 lines; duplicating them
20//!    buys independence — drill-down's value function can evolve
21//!    without touching the statistical axis implementations.
22//!
23//! 3. **Normalised ranking.** Raw deltas have wildly different scales
24//!    (0-1 for semantic, 0-10000 for latency_ms). Each axis has a
25//!    per-axis `scale` used to normalise deltas into [0, ~1] so they
26//!    can be summed into a single regression score. Scales are
27//!    calibrated against the axis's Severity::Severe threshold so a
28//!    per-pair normalised delta of 1.0 corresponds roughly to one
29//!    severity-severe-sized movement.
30//!
31//! 4. **Text similarity via character-shingle Jaccard.** The aggregate
32//!    semantic axis uses corpus-level TF-IDF, which is meaningless per
33//!    singleton-pair. Character-shingle Jaccard (what alignment.rs uses)
34//!    is the same proxy first-divergence detection uses, so drill-down
35//!    results are internally consistent with the first-divergence row.
36//!
37//! 5. **Judge is skipped.** The Rust core never populates the Judge
38//!    axis (it's Python-side). Including it here would produce spurious
39//!    zeros for every pair.
40
41use std::collections::BTreeSet;
42
43use serde::{Deserialize, Serialize};
44
45use crate::agentlog::{Kind, Record};
46use crate::diff::axes::Axis;
47use crate::diff::cost::Pricing;
48
/// Default number of pairs to surface in a drill-down list. Matches
/// `alignment::DEFAULT_K` so downstream renderers can share the same
/// "show a few inline, collapse the rest" heuristic.
// NOTE(review): the value is 5, but the heuristic was described as
// "show 3 inline" — confirm this still agrees with alignment::DEFAULT_K.
pub const DEFAULT_K: usize = 5;
53
/// One axis's contribution to a single pair's drill-down row.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct PairAxisScore {
    /// Which axis this score describes.
    pub axis: Axis,
    /// Axis-specific baseline value in raw units (ms, tokens, USD,
    /// similarity ratio, …).
    pub baseline_value: f64,
    /// Axis-specific candidate value in the same units as
    /// `baseline_value`.
    pub candidate_value: f64,
    /// `candidate_value - baseline_value`. Sign is direction; magnitude
    /// is raw axis units (ms, tokens, USD, …).
    pub delta: f64,
    /// `|delta| / axis_scale`, clamped to `[0, 4]` (see `clamp_norm`).
    /// Used as the per-axis contribution to the pair's
    /// `regression_score`. 0 means "no movement"; ~1 means "one
    /// severity-severe-sized movement on this axis"; ≥ 2 is
    /// unambiguous regression.
    pub normalized_delta: f64,
}
74
/// One pair's per-axis breakdown plus an aggregate regression score.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct PairDrilldown {
    /// 0-based index into the paired-responses list.
    pub pair_index: usize,
    /// The turn number in the baseline trace (counting only chat_responses).
    // NOTE(review): `compute_pair` currently fills this with the pair
    // index (pairing is positional) — confirm if pairing ever becomes
    // non-positional.
    pub baseline_turn: usize,
    /// The turn number in the candidate trace. Currently also the pair
    /// index, for the same reason as `baseline_turn`.
    pub candidate_turn: usize,
    /// Per-axis scores, in `Axis::all()` order (minus Judge).
    pub axis_scores: Vec<PairAxisScore>,
    /// Sum of `normalized_delta` across all included axes. Ranking key.
    pub regression_score: f64,
    /// The single axis that contributed the most to `regression_score`.
    /// Useful for "the regression at turn 4 was a trajectory change" one-
    /// liners in renderers.
    pub dominant_axis: Axis,
}
93
94/// Compute drill-down rows for every pair, return the top-`top_k`
95/// sorted by `regression_score` descending.
96///
97/// `top_k = 0` or `top_k >= pairs.len()` returns every pair.
98pub fn compute(
99    pairs: &[(&Record, &Record)],
100    pricing: &Pricing,
101    top_k: usize,
102) -> Vec<PairDrilldown> {
103    let mut rows: Vec<PairDrilldown> = pairs
104        .iter()
105        .enumerate()
106        .map(|(i, (b, c))| compute_pair(i, b, c, pricing))
107        .collect();
108    // Stable sort by regression_score descending; ties broken by
109    // pair_index ascending so output is deterministic.
110    rows.sort_by(|a, b| {
111        b.regression_score
112            .partial_cmp(&a.regression_score)
113            .unwrap_or(std::cmp::Ordering::Equal)
114            .then(a.pair_index.cmp(&b.pair_index))
115    });
116    if top_k > 0 && rows.len() > top_k {
117        rows.truncate(top_k);
118    }
119    rows
120}
121
122fn compute_pair(index: usize, b: &Record, c: &Record, pricing: &Pricing) -> PairDrilldown {
123    let scores: Vec<PairAxisScore> = vec![
124        axis_semantic(b, c),
125        axis_trajectory(b, c),
126        axis_safety(b, c),
127        axis_verbosity(b, c),
128        axis_latency(b, c),
129        axis_cost(b, c, pricing),
130        axis_reasoning(b, c),
131        axis_conformance(b, c),
132    ];
133    let regression_score: f64 = scores.iter().map(|s| s.normalized_delta).sum();
134    let dominant_axis = scores
135        .iter()
136        .max_by(|a, b| {
137            a.normalized_delta
138                .partial_cmp(&b.normalized_delta)
139                .unwrap_or(std::cmp::Ordering::Equal)
140        })
141        .map(|s| s.axis)
142        .unwrap_or(Axis::Semantic);
143    PairDrilldown {
144        pair_index: index,
145        baseline_turn: index,
146        candidate_turn: index,
147        axis_scores: scores,
148        regression_score,
149        dominant_axis,
150    }
151}
152
153// ---- per-axis extractors -------------------------------------------------
154//
155// Each returns a `PairAxisScore` with a normalised-delta scaled so that
156// 1.0 corresponds to one severity-severe-sized movement on that axis.
157// Scales are chosen to match the thresholds in `axes.rs::Severity::from_*`.
158
159/// Semantic axis: 1 − character-shingle-4 Jaccard similarity.
160/// Returns 0 for identical responses, 1 for totally disjoint.
161fn axis_semantic(b: &Record, c: &Record) -> PairAxisScore {
162    let sim = text_jaccard(&response_text(b), &response_text(c));
163    let delta = (1.0 - sim) - 0.0; // baseline "similarity to self" = 1
164    PairAxisScore {
165        axis: Axis::Semantic,
166        baseline_value: 1.0,
167        candidate_value: sim,
168        delta: sim - 1.0, // delta as "how far candidate sim drifted from 1"
169        normalized_delta: clamp_norm(delta / 0.5),
170    }
171}
172
173/// Trajectory axis: normalised Levenshtein over tool-shape sequence.
174fn axis_trajectory(b: &Record, c: &Record) -> PairAxisScore {
175    let bs = tool_shape_seq(b);
176    let cs = tool_shape_seq(c);
177    let div = normalised_edit_distance(&bs, &cs);
178    PairAxisScore {
179        axis: Axis::Trajectory,
180        baseline_value: 0.0,
181        candidate_value: div,
182        delta: div,
183        normalized_delta: clamp_norm(div / 0.5),
184    }
185}
186
187/// Safety axis: binary refusal indicator per side.
188fn axis_safety(b: &Record, c: &Record) -> PairAxisScore {
189    let br = is_refusal(b) as i32 as f64;
190    let cr = is_refusal(c) as i32 as f64;
191    PairAxisScore {
192        axis: Axis::Safety,
193        baseline_value: br,
194        candidate_value: cr,
195        delta: cr - br,
196        normalized_delta: clamp_norm((cr - br).abs()),
197    }
198}
199
200/// Verbosity axis: output_tokens.
201fn axis_verbosity(b: &Record, c: &Record) -> PairAxisScore {
202    let bv = output_tokens(b).unwrap_or(0.0);
203    let cv = output_tokens(c).unwrap_or(0.0);
204    PairAxisScore {
205        axis: Axis::Verbosity,
206        baseline_value: bv,
207        candidate_value: cv,
208        delta: cv - bv,
209        // One severe verbosity shift ≈ 100-token delta (calibrated
210        // against the severity thresholds in axes.rs).
211        normalized_delta: clamp_norm((cv - bv).abs() / 100.0),
212    }
213}
214
215/// Latency axis: latency_ms.
216fn axis_latency(b: &Record, c: &Record) -> PairAxisScore {
217    let bv = latency_ms(b).unwrap_or(0.0);
218    let cv = latency_ms(c).unwrap_or(0.0);
219    PairAxisScore {
220        axis: Axis::Latency,
221        baseline_value: bv,
222        candidate_value: cv,
223        delta: cv - bv,
224        // One severe latency shift ≈ 1000ms delta.
225        normalized_delta: clamp_norm((cv - bv).abs() / 1000.0),
226    }
227}
228
229/// Cost axis: tokens × pricing for this model.
230fn axis_cost(b: &Record, c: &Record, pricing: &Pricing) -> PairAxisScore {
231    let bc = cost_of(b, pricing);
232    let cc = cost_of(c, pricing);
233    PairAxisScore {
234        axis: Axis::Cost,
235        baseline_value: bc,
236        candidate_value: cc,
237        delta: cc - bc,
238        // One severe cost shift ≈ $0.01 delta per pair.
239        normalized_delta: clamp_norm((cc - bc).abs() / 0.01),
240    }
241}
242
243/// Reasoning axis: thinking_tokens from usage.
244fn axis_reasoning(b: &Record, c: &Record) -> PairAxisScore {
245    let bv = thinking_tokens(b).unwrap_or(0.0);
246    let cv = thinking_tokens(c).unwrap_or(0.0);
247    PairAxisScore {
248        axis: Axis::Reasoning,
249        baseline_value: bv,
250        candidate_value: cv,
251        delta: cv - bv,
252        normalized_delta: clamp_norm((cv - bv).abs() / 100.0),
253    }
254}
255
256/// Conformance axis: does the response body parse as JSON? Binary.
257fn axis_conformance(b: &Record, c: &Record) -> PairAxisScore {
258    let bp = parses_as_json(&response_text(b)) as i32 as f64;
259    let cp = parses_as_json(&response_text(c)) as i32 as f64;
260    PairAxisScore {
261        axis: Axis::Conformance,
262        baseline_value: bp,
263        candidate_value: cp,
264        delta: cp - bp,
265        // A loss of JSON parseability is an outright severe signal.
266        normalized_delta: clamp_norm((cp - bp).abs()),
267    }
268}
269
270// ---- small extractors (self-contained so drill-down owns its logic) ------
271
272fn response_text(r: &Record) -> String {
273    if r.kind != Kind::ChatResponse {
274        return String::new();
275    }
276    let arr = match r.payload.get("content").and_then(|c| c.as_array()) {
277        Some(a) => a,
278        None => return String::new(),
279    };
280    let mut out = String::new();
281    for part in arr {
282        if part.get("type").and_then(|t| t.as_str()) == Some("text") {
283            if let Some(t) = part.get("text").and_then(|t| t.as_str()) {
284                if !out.is_empty() {
285                    out.push('\n');
286                }
287                out.push_str(t);
288            }
289        }
290    }
291    out
292}
293
294fn tool_shape_seq(r: &Record) -> Vec<String> {
295    let arr = match r.payload.get("content").and_then(|c| c.as_array()) {
296        Some(a) => a,
297        None => return Vec::new(),
298    };
299    let mut out = Vec::new();
300    for part in arr {
301        if part.get("type").and_then(|t| t.as_str()) == Some("tool_use") {
302            let name = part
303                .get("name")
304                .and_then(|n| n.as_str())
305                .unwrap_or("_")
306                .to_string();
307            let mut keys: Vec<String> = part
308                .get("input")
309                .and_then(|i| i.as_object())
310                .map(|o| o.keys().cloned().collect())
311                .unwrap_or_default();
312            keys.sort();
313            out.push(format!("{name}({})", keys.join(",")));
314        }
315    }
316    out
317}
318
319fn latency_ms(r: &Record) -> Option<f64> {
320    r.payload.get("latency_ms").and_then(|v| v.as_f64())
321}
322
323fn output_tokens(r: &Record) -> Option<f64> {
324    r.payload
325        .get("usage")
326        .and_then(|u| u.get("output_tokens"))
327        .and_then(|v| v.as_f64())
328}
329
330fn thinking_tokens(r: &Record) -> Option<f64> {
331    r.payload
332        .get("usage")
333        .and_then(|u| u.get("thinking_tokens"))
334        .and_then(|v| v.as_f64())
335}
336
337fn is_refusal(r: &Record) -> bool {
338    match r.payload.get("stop_reason").and_then(|s| s.as_str()) {
339        Some("content_filter") | Some("refusal") => return true,
340        _ => {}
341    }
342    let text = response_text(r).to_lowercase();
343    // Conservative refusal indicators (matching safety axis heuristics).
344    text.contains("i can't help")
345        || text.contains("i cannot help")
346        || text.contains("i'm unable")
347        || text.contains("i am unable")
348        || text.contains("i won't")
349        || text.contains("i will not")
350}
351
352fn parses_as_json(text: &str) -> bool {
353    let trimmed = text.trim();
354    if trimmed.is_empty() {
355        return false;
356    }
357    // Accept values wrapped in code fences too (mirrors conformance
358    // axis's tolerance).
359    let unfenced = if let Some(s) = trimmed.strip_prefix("```json") {
360        s.trim().trim_end_matches("```").trim()
361    } else if let Some(s) = trimmed.strip_prefix("```") {
362        s.trim().trim_end_matches("```").trim()
363    } else {
364        trimmed
365    };
366    serde_json::from_str::<serde_json::Value>(unfenced).is_ok()
367}
368
369fn cost_of(r: &Record, pricing: &Pricing) -> f64 {
370    crate::diff::cost::cost_of(r, pricing).unwrap_or(0.0)
371}
372
373// ---- shared helpers ------------------------------------------------------
374
/// Normalise into `[0, 4]`: absolute value capped at 4, with NaN
/// mapped to 0 so a bad axis value can never poison the ranking sum.
fn clamp_norm(v: f64) -> f64 {
    if v.is_nan() {
        0.0
    } else {
        v.abs().min(4.0)
    }
}
381
/// Jaccard similarity over character 4-shingles. Two empty strings are
/// considered identical (similarity 1.0).
fn text_jaccard(a: &str, b: &str) -> f64 {
    let left = shingles(a, 4);
    let right = shingles(b, 4);
    if left.is_empty() && right.is_empty() {
        return 1.0;
    }
    let union = left.union(&right).count();
    if union == 0 {
        return 1.0;
    }
    left.intersection(&right).count() as f64 / union as f64
}

/// Character `k`-shingles of `s`. A non-empty string shorter than `k`
/// contributes itself as a single shingle; an empty string contributes
/// nothing.
fn shingles(s: &str, k: usize) -> BTreeSet<String> {
    let chars: Vec<char> = s.chars().collect();
    if chars.len() < k {
        return if s.is_empty() {
            BTreeSet::new()
        } else {
            std::iter::once(s.to_string()).collect()
        };
    }
    chars.windows(k).map(|w| w.iter().collect()).collect()
}
411
/// Levenshtein distance divided by the longer sequence's length.
/// Two empty sequences are perfectly aligned (0.0).
fn normalised_edit_distance(a: &[String], b: &[String]) -> f64 {
    match a.len().max(b.len()) {
        0 => 0.0,
        longest => levenshtein(a, b) as f64 / longest as f64,
    }
}

/// Classic single-row dynamic-programming Levenshtein over string
/// sequences (unit insert/delete/substitute costs).
fn levenshtein(a: &[String], b: &[String]) -> usize {
    if a.is_empty() {
        return b.len();
    }
    if b.is_empty() {
        return a.len();
    }
    let width = b.len();
    // `row` holds the previous DP row; `diag` carries the value that
    // `row[j]` had before it was overwritten (i.e. prev[j-1]).
    let mut row: Vec<usize> = (0..=width).collect();
    for (i, av) in a.iter().enumerate() {
        let mut diag = row[0];
        row[0] = i + 1;
        for (j, bv) in b.iter().enumerate() {
            let substitute = diag + usize::from(av != bv);
            diag = row[j + 1];
            row[j + 1] = substitute.min(row[j] + 1).min(diag + 1);
        }
    }
    row[width]
}
440
#[cfg(test)]
mod tests {
    use super::*;
    use crate::agentlog::Kind;
    use serde_json::json;

    /// Minimal ChatResponse fixture: one text content part with
    /// configurable latency, output-token count, and body text.
    fn resp(latency: u64, out_tokens: u64, text: &str) -> Record {
        Record::new(
            Kind::ChatResponse,
            json!({
                "model": "claude-sonnet-4-6",
                "content": [{"type": "text", "text": text}],
                "stop_reason": "end_turn",
                "latency_ms": latency,
                "usage": {
                    "input_tokens": 10,
                    "output_tokens": out_tokens,
                    "thinking_tokens": 0,
                },
            }),
            "2026-04-21T10:00:00Z",
            None,
        )
    }

    // A record paired with itself should produce (near-)zero deltas on
    // every axis and therefore a near-zero aggregate regression score.
    #[test]
    fn identical_responses_have_zero_regression() {
        let r = resp(100, 20, "hello world");
        let pairs = vec![(&r, &r)];
        let out = compute(&pairs, &Pricing::new(), 0);
        assert_eq!(out.len(), 1);
        assert!(
            out[0].regression_score < 0.01,
            "expected near-zero, got {}",
            out[0].regression_score
        );
    }

    // Sorting is by regression_score descending, so the pair with large
    // latency/verbosity/semantic deltas must surface first.
    #[test]
    fn divergent_pair_scores_higher_than_matched_pair() {
        let match_a = resp(100, 20, "hello world");
        let match_b = resp(100, 20, "hello world");
        let diverge_a = resp(100, 20, "hello world");
        let diverge_b = resp(2500, 200, "totally different output");
        let pairs = vec![(&match_a, &match_b), (&diverge_a, &diverge_b)];
        let out = compute(&pairs, &Pricing::new(), 0);
        assert_eq!(out.len(), 2);
        // First in the sorted output is the divergent pair.
        assert_eq!(out[0].pair_index, 1);
        assert!(out[0].regression_score > out[1].regression_score);
    }

    // top_k > 0 limits the output to the top-k rows.
    #[test]
    fn top_k_truncates_result_list() {
        let rs: Vec<Record> = (0..10)
            .map(|i| resp(100 + i * 50, 20, &format!("response {}", i)))
            .collect();
        let pairs: Vec<(&Record, &Record)> = rs.iter().zip(rs.iter().rev()).collect();
        let out = compute(&pairs, &Pricing::new(), 3);
        assert_eq!(out.len(), 3);
    }

    // Equal scores must fall back to pair_index ascending, so repeated
    // runs over identical input produce identical ordering.
    #[test]
    fn ranking_is_deterministic_on_ties() {
        // Two pairs with identical regression: tie-break by pair_index asc.
        let a = resp(100, 20, "hello");
        let b = resp(200, 30, "hello");
        let pairs = vec![(&a, &b), (&a, &b), (&a, &b)];
        let out1 = compute(&pairs, &Pricing::new(), 0);
        let out2 = compute(&pairs, &Pricing::new(), 0);
        assert_eq!(out1, out2);
        assert_eq!(
            out1.iter().map(|r| r.pair_index).collect::<Vec<_>>(),
            vec![0, 1, 2]
        );
    }

    // Swapping search(query) for fetch(url) with all other fields equal
    // should make Trajectory the dominant axis for the pair.
    #[test]
    fn tool_shape_change_surfaces_trajectory_as_dominant() {
        let baseline = Record::new(
            Kind::ChatResponse,
            json!({
                "model": "x",
                "content": [
                    {"type": "tool_use", "name": "search", "input": {"query": "x"}},
                ],
                "stop_reason": "end_turn",
                "latency_ms": 100,
                "usage": {"input_tokens": 10, "output_tokens": 5, "thinking_tokens": 0},
            }),
            "ts",
            None,
        );
        let candidate = Record::new(
            Kind::ChatResponse,
            json!({
                "model": "x",
                "content": [
                    {"type": "tool_use", "name": "fetch", "input": {"url": "x"}},
                ],
                "stop_reason": "end_turn",
                "latency_ms": 100,
                "usage": {"input_tokens": 10, "output_tokens": 5, "thinking_tokens": 0},
            }),
            "ts",
            None,
        );
        let pairs = vec![(&baseline, &candidate)];
        let out = compute(&pairs, &Pricing::new(), 0);
        assert_eq!(out.len(), 1);
        assert_eq!(out[0].dominant_axis, Axis::Trajectory);
    }

    // A refusal phrase appearing only on the candidate side should fire
    // the binary safety axis.
    #[test]
    fn refusal_surfaces_safety_axis() {
        let b = resp(100, 20, "Here you go.");
        let c = resp(100, 20, "I can't help with that.");
        let pairs = vec![(&b, &c)];
        let out = compute(&pairs, &Pricing::new(), 0);
        let safety = out[0]
            .axis_scores
            .iter()
            .find(|s| s.axis == Axis::Safety)
            .unwrap();
        assert!(
            safety.normalized_delta > 0.5,
            "expected safety axis to fire, got {}",
            safety.normalized_delta
        );
    }

    // Baseline parses as JSON, candidate does not: conformance must
    // record the 1.0 -> 0.0 flip.
    #[test]
    fn json_loss_surfaces_conformance_axis() {
        let b = resp(100, 20, r#"{"ok": true}"#);
        let c = resp(100, 20, "sure thing");
        let pairs = vec![(&b, &c)];
        let out = compute(&pairs, &Pricing::new(), 0);
        let conf = out[0]
            .axis_scores
            .iter()
            .find(|s| s.axis == Axis::Conformance)
            .unwrap();
        assert_eq!(conf.baseline_value, 1.0);
        assert_eq!(conf.candidate_value, 0.0);
        assert!(conf.normalized_delta > 0.5);
    }

    // Degenerate input: no pairs in, no rows out.
    #[test]
    fn empty_pairs_returns_empty_vec() {
        let pairs: Vec<(&Record, &Record)> = Vec::new();
        let out = compute(&pairs, &Pricing::new(), 5);
        assert!(out.is_empty());
    }
}
594}