Skip to main content

shadow_core/diff/
axes.rs

1//! Shared types for the nine-axis behavioral diff.
2
3use serde::{Deserialize, Serialize};
4
5/// The nine behavioral axes.
6#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
7#[serde(rename_all = "snake_case")]
8pub enum Axis {
9    /// Final-output semantic similarity (embedding + structural).
10    Semantic,
11    /// Tool-call trajectory divergence (edit distance).
12    Trajectory,
13    /// Refusal / safety-filter rate.
14    Safety,
15    /// Output-token count CDF.
16    Verbosity,
17    /// End-to-end latency CDF.
18    Latency,
19    /// Cost distribution (tokens × pricing).
20    Cost,
21    /// Reasoning-depth (thinking tokens + self-correction markers).
22    Reasoning,
23    /// LLM-judge score (user-supplied rubric).
24    Judge,
25    /// Schema / format conformance rate.
26    Conformance,
27}
28
29impl Axis {
30    /// Human-readable label for terminal and markdown renderers.
31    pub fn label(&self) -> &'static str {
32        match self {
33            Axis::Semantic => "semantic similarity",
34            Axis::Trajectory => "tool-call trajectory",
35            Axis::Safety => "refusal / safety",
36            Axis::Verbosity => "verbosity",
37            Axis::Latency => "latency",
38            Axis::Cost => "cost",
39            Axis::Reasoning => "reasoning depth",
40            Axis::Judge => "llm-judge score",
41            Axis::Conformance => "format conformance",
42        }
43    }
44
45    /// All nine axes, in report order.
46    pub fn all() -> [Axis; 9] {
47        [
48            Axis::Semantic,
49            Axis::Trajectory,
50            Axis::Safety,
51            Axis::Verbosity,
52            Axis::Latency,
53            Axis::Cost,
54            Axis::Reasoning,
55            Axis::Judge,
56            Axis::Conformance,
57        ]
58    }
59}
60
61/// Severity classification of a per-axis delta.
62#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord, Serialize, Deserialize)]
63#[serde(rename_all = "snake_case")]
64pub enum Severity {
65    /// No meaningful difference (abs(delta) within CI noise).
66    None,
67    /// Small effect (detectable but within 10% relative).
68    Minor,
69    /// Notable (10–30% relative).
70    Moderate,
71    /// Large (>30% relative, or CI excludes zero by a wide margin).
72    Severe,
73}
74
75/// Caveat flags attached to a per-axis result.
76///
77/// Flags explain *why the severity is what it is* — they surface the
78/// statistical caveats that users would otherwise have to read the CI
79/// and `n` column to spot. A severity of `Severe` with no flags is
80/// strong; a severity of `Severe` with `LowPower` means "the trend
81/// looks large but our sample was too small to be confident."
82#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord, Serialize, Deserialize)]
83#[serde(rename_all = "snake_case")]
84pub enum Flag {
85    /// `n < 5` — we don't have enough paired observations to be
86    /// confident in the CI. Bootstrap on tiny samples produces wide
87    /// intervals; treat severities as directional, not definitive.
88    LowPower,
89    /// The 95% CI includes zero. At the 95% confidence level we cannot
90    /// reject "no effect"; the observed delta may be noise. Severity
91    /// is capped at `Minor` in this case regardless of `|delta|`.
92    CiCrossesZero,
93    /// Cost axis could not price any pair — either no pricing table was
94    /// supplied or the table has no entries for the traced models.
95    /// A `0.00` cost delta in this state means "unknown", not "equal".
96    /// Emitted only on [`Axis::Cost`].
97    NoPricing,
98}
99
100impl Flag {
101    /// Short machine-readable label for terminal / markdown / JSON output.
102    pub fn label(&self) -> &'static str {
103        match self {
104            Flag::LowPower => "low_power",
105            Flag::CiCrossesZero => "ci_crosses_zero",
106            Flag::NoPricing => "no_pricing",
107        }
108    }
109}
110
111/// Compute the caveat flags for an axis given its CI bounds and `n`.
112pub fn compute_flags(ci95_low: f64, ci95_high: f64, n: usize) -> Vec<Flag> {
113    let mut flags = Vec::new();
114    if n < 5 && n > 0 {
115        flags.push(Flag::LowPower);
116    }
117    // NaN-safe CI-straddles-zero check. Strict: a bound of exactly 0.0
118    // is a boundary artifact (rate axes saturated at 0 or 1), not
119    // genuine uncertainty about direction.
120    if ci95_low.is_finite() && ci95_high.is_finite() && ci95_low < -1e-9 && ci95_high > 1e-9 {
121        flags.push(Flag::CiCrossesZero);
122    }
123    flags
124}
125
126impl Severity {
127    /// Classify a delta given the axis's 95% CI bounds.
128    ///
129    /// Rules:
130    /// - If the CI crosses zero and the midpoint delta is small → None
131    /// - If the CI crosses zero with any larger delta → capped at Minor
132    ///   (we cannot reject "no effect" at 95%)
133    /// - else if abs(rel_delta) < 0.1 → Minor
134    /// - else if abs(rel_delta) < 0.3 → Moderate
135    /// - else Severe
136    ///
137    /// `baseline_median` may be zero; if so, Minor is returned when delta is
138    /// non-zero (avoiding divide-by-zero).
139    pub fn classify(delta: f64, baseline_median: f64, ci95_low: f64, ci95_high: f64) -> Severity {
140        // Reject NaN/Inf inputs explicitly. Rust's NaN comparisons always
141        // return false, which would silently fall through to the rel-based
142        // branches and return Severe on corrupt data — worst possible outcome
143        // for a diff report.
144        if !(delta.is_finite()
145            && baseline_median.is_finite()
146            && ci95_low.is_finite()
147            && ci95_high.is_finite())
148        {
149            return Severity::None;
150        }
151        if delta.abs() < 1e-9 {
152            // Exactly (or near-exactly) zero delta → nothing moved.
153            return Severity::None;
154        }
155        // CI "straddles zero" means we genuinely can't determine direction:
156        // both bounds are on opposite sides of zero by a meaningful margin.
157        // A bound of exactly 0.0 is a boundary artifact (rate-bounded axes,
158        // integer-valued statistics, etc.), not uncertainty — so don't
159        // downgrade on it.
160        let ci_straddles_zero = ci95_low < -1e-9 && ci95_high > 1e-9;
161        if ci_straddles_zero && delta.abs() < f64::max(baseline_median.abs() * 0.05, 1e-9) {
162            return Severity::None;
163        }
164        let base = if baseline_median.abs() < 1e-9 {
165            if delta.abs() < 1e-9 {
166                Severity::None
167            } else {
168                Severity::Minor
169            }
170        } else {
171            let rel = (delta / baseline_median).abs();
172            if rel < 0.10 {
173                Severity::Minor
174            } else if rel < 0.30 {
175                Severity::Moderate
176            } else {
177                Severity::Severe
178            }
179        };
180        // Only downgrade when the delta is small relative to CI width —
181        // i.e. the signal is weak compared to noise. A unanimous (|delta|
182        // ≥ CI width) observation should NOT be downgraded just because
183        // a bootstrap resample happened to cross zero.
184        if ci_straddles_zero && base > Severity::Minor {
185            let ci_width = (ci95_high - ci95_low).abs();
186            let delta_dominates = ci_width < 1e-9 || delta.abs() >= ci_width;
187            if !delta_dominates {
188                return Severity::Minor;
189            }
190        }
191        base
192    }
193
194    /// Classify a rate-like axis (values bounded in `[0, 1]`) by absolute
195    /// magnitude of the delta. Used by [`crate::diff::safety`] and
196    /// [`crate::diff::conformance`], where a shift from 0.0 → 0.33 is
197    /// "1/3 of traffic flipped" — real, not noise.
198    ///
199    /// Thresholds:
200    /// - CI crosses zero AND `|delta| < 1e-9` → None
201    /// - CI crosses zero with any larger delta → capped at Minor
202    /// - `|delta| < 0.05` → Minor
203    /// - `|delta| < 0.15` → Moderate
204    /// - else → Severe
205    pub fn classify_rate(delta: f64, ci95_low: f64, ci95_high: f64) -> Severity {
206        // NaN guard — see classify() above.
207        if !(delta.is_finite() && ci95_low.is_finite() && ci95_high.is_finite()) {
208            return Severity::None;
209        }
210        let abs = delta.abs();
211        if abs < 1e-9 {
212            return Severity::None;
213        }
214        // Strict straddling: a CI bound of exactly 0.0 is a boundary
215        // artifact for rate-axes bounded in [0,1] (e.g. saturated trajectory
216        // divergence where every pair has divergence 1.0 or 0.0), not
217        // statistical uncertainty.
218        let ci_straddles_zero = ci95_low < -1e-9 && ci95_high > 1e-9;
219        let base = if abs < 0.05 {
220            Severity::Minor
221        } else if abs < 0.15 {
222            Severity::Moderate
223        } else {
224            Severity::Severe
225        };
226        // Only downgrade when delta is small relative to CI width.
227        // Unanimous +1.0 delta with CI=[0,1] should remain Severe — the
228        // point estimate dominates the CI.
229        if ci_straddles_zero && base > Severity::Minor {
230            let ci_width = (ci95_high - ci95_low).abs();
231            let delta_dominates = ci_width < 1e-9 || abs >= ci_width;
232            if !delta_dominates {
233                return Severity::Minor;
234            }
235        }
236        base
237    }
238
239    /// Short string for reports: "none" / "minor" / "moderate" / "severe".
240    pub fn label(&self) -> &'static str {
241        match self {
242            Severity::None => "none",
243            Severity::Minor => "minor",
244            Severity::Moderate => "moderate",
245            Severity::Severe => "severe",
246        }
247    }
248}
249
250/// One axis's statistical result.
251#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
252pub struct AxisStat {
253    /// Which axis this row describes.
254    pub axis: Axis,
255    /// Median of the baseline sample.
256    pub baseline_median: f64,
257    /// Median of the candidate sample.
258    pub candidate_median: f64,
259    /// `candidate_median - baseline_median`.
260    pub delta: f64,
261    /// Lower bound of the 95% bootstrap CI of the delta.
262    pub ci95_low: f64,
263    /// Upper bound of the 95% bootstrap CI of the delta.
264    pub ci95_high: f64,
265    /// Severity classification per [`Severity::classify`].
266    pub severity: Severity,
267    /// Number of paired observations the axis was computed from. Zero
268    /// means the axis had nothing to measure (e.g. no tool calls in
269    /// either side → Trajectory axis is `n=0`).
270    pub n: usize,
271    /// Caveat flags — e.g. `low_power` (n<5) or `ci_crosses_zero`.
272    #[serde(default)]
273    pub flags: Vec<Flag>,
274}
275
276impl AxisStat {
277    /// Build a "no data" row — used when the axis had nothing to measure.
278    pub fn empty(axis: Axis) -> Self {
279        Self {
280            axis,
281            baseline_median: 0.0,
282            candidate_median: 0.0,
283            delta: 0.0,
284            ci95_low: 0.0,
285            ci95_high: 0.0,
286            severity: Severity::None,
287            n: 0,
288            flags: Vec::new(),
289        }
290    }
291
292    /// Build an axis row for a continuous-valued axis (latency, verbosity,
293    /// cost, reasoning, ...). Severity uses the relative-delta thresholds
294    /// via [`Severity::classify`]; flags come from [`compute_flags`].
295    #[allow(clippy::too_many_arguments)]
296    pub fn new_value(
297        axis: Axis,
298        baseline_median: f64,
299        candidate_median: f64,
300        delta: f64,
301        ci95_low: f64,
302        ci95_high: f64,
303        n: usize,
304    ) -> Self {
305        Self {
306            axis,
307            baseline_median,
308            candidate_median,
309            delta,
310            ci95_low,
311            ci95_high,
312            severity: Severity::classify(delta, baseline_median, ci95_low, ci95_high),
313            n,
314            flags: compute_flags(ci95_low, ci95_high, n),
315        }
316    }
317
318    /// Build an axis row for a rate-like axis (values in `[0, 1]`:
319    /// safety, conformance). Uses [`Severity::classify_rate`].
320    #[allow(clippy::too_many_arguments)]
321    pub fn new_rate(
322        axis: Axis,
323        baseline_median: f64,
324        candidate_median: f64,
325        delta: f64,
326        ci95_low: f64,
327        ci95_high: f64,
328        n: usize,
329    ) -> Self {
330        Self {
331            axis,
332            baseline_median,
333            candidate_median,
334            delta,
335            ci95_low,
336            ci95_high,
337            severity: Severity::classify_rate(delta, ci95_low, ci95_high),
338            n,
339            flags: compute_flags(ci95_low, ci95_high, n),
340        }
341    }
342}
343
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks `Severity::classify` against rows of
    /// `(delta, baseline_median, ci95_low, ci95_high, expected)`.
    fn assert_classify(rows: &[(f64, f64, f64, f64, Severity)]) {
        for &(delta, baseline, lo, hi, expected) in rows {
            let got = Severity::classify(delta, baseline, lo, hi);
            assert_eq!(
                got, expected,
                "classify({}, {}, {}, {})",
                delta, baseline, lo, hi
            );
        }
    }

    /// Checks `Severity::classify_rate` against rows of
    /// `(delta, ci95_low, ci95_high, expected)`.
    fn assert_classify_rate(rows: &[(f64, f64, f64, Severity)]) {
        for &(delta, lo, hi, expected) in rows {
            let got = Severity::classify_rate(delta, lo, hi);
            assert_eq!(got, expected, "classify_rate({}, {}, {})", delta, lo, hi);
        }
    }

    #[test]
    fn axis_all_has_nine_entries() {
        let axes = Axis::all();
        assert_eq!(axes.len(), 9);
    }

    #[test]
    fn axis_labels_are_unique() {
        let distinct: std::collections::HashSet<&str> =
            Axis::all().iter().map(|axis| axis.label()).collect();
        assert_eq!(distinct.len(), 9);
    }

    #[test]
    fn severity_ci_crossing_zero_with_tiny_delta_is_none() {
        // A 1-unit wobble on a baseline of 100, with the CI spanning zero.
        assert_classify(&[(1.0, 100.0, -10.0, 10.0, Severity::None)]);
    }

    #[test]
    fn severity_minor_moderate_severe_thresholds() {
        assert_classify(&[
            (5.0, 100.0, 3.0, 7.0, Severity::Minor),       // 5% relative
            (20.0, 100.0, 15.0, 25.0, Severity::Moderate), // 20% relative
            (50.0, 100.0, 40.0, 60.0, Severity::Severe),   // 50% relative
        ]);
    }

    #[test]
    fn severity_capped_at_minor_when_ci_crosses_zero_with_small_delta() {
        // The relative delta alone would be Severe, but the CI [-200, +300]
        // is ten times wider than the delta of 50: noise dominates, so the
        // result is capped at Minor.
        assert_classify(&[(50.0, 100.0, -200.0, 300.0, Severity::Minor)]);
    }

    #[test]
    fn severity_not_downgraded_when_delta_dominates_ci_even_if_ci_straddles_zero() {
        // 100% relative change. The CI [-5, +95] does straddle zero, yet the
        // delta (100) is at least the CI width (100), so it stays Severe.
        // Guards against the v0.1 bug where saturated axes with legitimately
        // large deltas were reported as Minor.
        assert_classify(&[(100.0, 100.0, -5.0, 95.0, Severity::Severe)]);
    }

    #[test]
    fn severity_rate_unanimous_saturated_delta_stays_severe() {
        // REGRESSION TEST: trajectory axis saturated at +1.0 (every pair
        // diverged completely). A lower bound touching exactly 0.0 is a
        // boundary artifact of rate-bounded resampling, not uncertainty,
        // and must NOT trigger the Minor cap.
        assert_classify_rate(&[(1.0, 0.0, 1.0, Severity::Severe)]);
    }

    #[test]
    fn severity_rate_capped_when_ci_genuinely_straddles_and_delta_is_small() {
        // delta = 0.2 inside CI [-0.3, +0.7] (width 1.0): genuinely
        // ambiguous, so the result is capped at Minor.
        assert_classify_rate(&[(0.2, -0.3, 0.7, Severity::Minor)]);
    }

    #[test]
    fn compute_flags_detects_low_power_and_ci_crosses_zero() {
        // [-1, +1] is a genuine straddle; n = 3 is a tiny sample.
        let both = compute_flags(-1.0, 1.0, 3);
        assert_eq!(both, vec![Flag::LowPower, Flag::CiCrossesZero]);

        let only_low_power = compute_flags(0.5, 1.0, 3);
        assert_eq!(only_low_power, vec![Flag::LowPower]);

        let only_crossing = compute_flags(-1.0, 1.0, 50);
        assert_eq!(only_crossing, vec![Flag::CiCrossesZero]);

        let neither = compute_flags(0.5, 1.0, 50);
        assert!(neither.is_empty());
    }

    #[test]
    fn compute_flags_does_not_flag_boundary_touching_ci() {
        // [0, 1] (rate saturated at 0) and its mirror image [-1, 0] only
        // touch zero; neither is a straddle.
        for &(lo, hi) in &[(0.0, 1.0), (-1.0, 0.0)] {
            assert!(!compute_flags(lo, hi, 50).contains(&Flag::CiCrossesZero));
        }
    }

    #[test]
    fn severity_classify_rejects_nan_inputs() {
        // Every comparison against NaN is false, so without an explicit
        // guard corrupt data would silently come out as Severe.
        assert_classify(&[
            (f64::NAN, 100.0, -10.0, 10.0, Severity::None),
            (5.0, f64::NAN, -10.0, 10.0, Severity::None),
            (5.0, 100.0, f64::NAN, 10.0, Severity::None),
            (5.0, 100.0, -10.0, f64::INFINITY, Severity::None),
        ]);
    }

    #[test]
    fn severity_classify_rate_rejects_nan_inputs() {
        assert_classify_rate(&[
            (f64::NAN, 0.0, 1.0, Severity::None),
            (0.5, f64::NAN, 1.0, Severity::None),
        ]);
    }

    #[test]
    fn compute_flags_ignores_nan_ci_bounds() {
        // A NaN bound cannot be said to lie on either side of zero.
        let has_crossing = compute_flags(f64::NAN, 1.0, 10).contains(&Flag::CiCrossesZero);
        assert!(!has_crossing);
    }

    #[test]
    fn compute_flags_n_zero_means_no_low_power() {
        // n = 0 is the "nothing to measure" sentinel, not a tiny sample.
        let flags = compute_flags(0.0, 0.0, 0);
        assert!(flags.is_empty());
    }

    #[test]
    fn severity_handles_zero_baseline() {
        assert_classify(&[
            (0.0, 0.0, 0.0, 0.0, Severity::None),
            (1.0, 0.0, 0.5, 1.5, Severity::Minor),
        ]);
    }

    #[test]
    fn severity_labels_distinguish_four_levels() {
        let levels = [
            Severity::None,
            Severity::Minor,
            Severity::Moderate,
            Severity::Severe,
        ];
        let mut labels: Vec<&str> = Vec::with_capacity(levels.len());
        for level in levels.iter() {
            labels.push(level.label());
        }
        assert_eq!(labels, vec!["none", "minor", "moderate", "severe"]);
    }

    #[test]
    fn axis_stat_empty_has_zero_n() {
        let row = AxisStat::empty(Axis::Latency);
        assert_eq!(row.n, 0);
        assert_eq!(row.severity, Severity::None);
    }
}