Skip to main content

shadow_core/diff/
cost.rs

1//! Axis 6: cost (input+output tokens × per-model pricing).
2//!
3//! Pricing is richer than just (input, output). Modern frontier models
4//! bill differently for:
5//!
6//! - **Cached input tokens**: Anthropic prompt caching, OpenAI prompt
7//!   caching — typically 10% of the uncached rate.
8//! - **Reasoning tokens**: GPT-5+ reasoning / o1-style thinking tokens
9//!   exposed via `usage.completion_tokens_details.reasoning_tokens` —
10//!   usually billed at the output rate.
11//! - **Batch API**: OpenAI and Anthropic batch APIs are ~50% off.
12//!
13//! Unknown models still contribute 0 cost rather than phantom infinities.
14
15use std::collections::HashMap;
16
17use serde::{Deserialize, Serialize};
18
19use crate::agentlog::Record;
20use crate::diff::axes::{Axis, AxisStat, Flag};
21use crate::diff::bootstrap::{median, paired_ci};
22
23/// Per-model pricing in USD per token.
24#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
25pub struct ModelPricing {
26    /// USD per uncached input token.
27    pub input: f64,
28    /// USD per output token.
29    pub output: f64,
30    /// USD per cached input READ token (cache hit). For Anthropic this is
31    /// ~10% of `input`; for OpenAI ~50%. If 0.0, cached reads are billed
32    /// at the full input rate.
33    #[serde(default)]
34    pub cached_input: f64,
35    /// USD per cached input WRITE token, 5-minute TTL tier. Anthropic
36    /// charges ~1.25× input for 5m-ephemeral cache creation. If 0.0,
37    /// cache writes are billed at the uncached input rate.
38    #[serde(default)]
39    pub cached_write_5m: f64,
40    /// USD per cached input WRITE token, 1-hour TTL tier. Anthropic
41    /// charges ~2.0× input for 1h-ephemeral cache creation.
42    #[serde(default)]
43    pub cached_write_1h: f64,
44    /// USD per reasoning / thinking token. If 0.0, reasoning tokens
45    /// are billed at the `output` rate.
46    #[serde(default)]
47    pub reasoning: f64,
48    /// Multiplier applied to the final per-call cost when
49    /// `meta.batch == true` — e.g. 0.5 for a 50% batch discount.
50    /// 1.0 (or 0.0 as a sentinel) means no discount.
51    #[serde(default)]
52    pub batch_discount: f64,
53}
54
55impl ModelPricing {
56    /// Simple constructor that assumes no caching, no reasoning, no batch.
57    pub fn simple(input: f64, output: f64) -> Self {
58        Self {
59            input,
60            output,
61            cached_input: 0.0,
62            cached_write_5m: 0.0,
63            cached_write_1h: 0.0,
64            reasoning: 0.0,
65            batch_discount: 0.0,
66        }
67    }
68}
69
70/// Pricing table keyed by `chat_response.payload.model`.
71pub type Pricing = HashMap<String, ModelPricing>;
72
73/// Look up pricing for a model name with a fallback that strips a
74/// trailing dated snapshot suffix (`gpt-5-2025-08-07` →`gpt-5`,
75/// `claude-opus-4-7-20250219` → `claude-opus-4-7`). Provider SDKs
76/// almost always return the dated snapshot in chat-response bodies,
77/// while pricing tables are usually keyed by the bare alias.
78fn lookup_with_snapshot_fallback<'a>(
79    pricing: &'a Pricing,
80    model: &str,
81) -> Option<&'a ModelPricing> {
82    if let Some(p) = pricing.get(model) {
83        return Some(p);
84    }
85    if let Some(base) = strip_snapshot_tail(model) {
86        return pricing.get(base);
87    }
88    None
89}
90
91/// If `model` ends with a dated snapshot tail (`-YYYY-MM-DD` or `-YYYYMMDD`),
92/// return the prefix without it. Otherwise None.
93fn strip_snapshot_tail(model: &str) -> Option<&str> {
94    let bytes = model.as_bytes();
95    // Try -YYYY-MM-DD (length 11).
96    if bytes.len() > 11 && bytes[bytes.len() - 11] == b'-' {
97        let tail = &bytes[bytes.len() - 10..];
98        if tail.len() == 10
99            && tail[..4].iter().all(u8::is_ascii_digit)
100            && tail[4] == b'-'
101            && tail[5..7].iter().all(u8::is_ascii_digit)
102            && tail[7] == b'-'
103            && tail[8..10].iter().all(u8::is_ascii_digit)
104        {
105            return Some(&model[..model.len() - 11]);
106        }
107    }
108    // Try -YYYYMMDD (length 9).
109    if bytes.len() > 9 && bytes[bytes.len() - 9] == b'-' {
110        let tail = &bytes[bytes.len() - 8..];
111        if tail.len() == 8 && tail.iter().all(u8::is_ascii_digit) {
112            return Some(&model[..model.len() - 9]);
113        }
114    }
115    None
116}
117
118pub(crate) fn cost_of(r: &Record, pricing: &Pricing) -> Option<f64> {
119    let model = r.payload.get("model")?.as_str()?;
120    let usage = r.payload.get("usage")?;
121    let input = usage.get("input_tokens")?.as_f64()?;
122    let output = usage.get("output_tokens")?.as_f64()?;
123    let cached_input = usage
124        .get("cached_input_tokens")
125        .and_then(|v| v.as_f64())
126        .unwrap_or(0.0);
127    let cached_write_5m = usage
128        .get("cached_write_5m_tokens")
129        .and_then(|v| v.as_f64())
130        .unwrap_or(0.0);
131    let cached_write_1h = usage
132        .get("cached_write_1h_tokens")
133        .and_then(|v| v.as_f64())
134        .unwrap_or(0.0);
135    let thinking = usage
136        .get("thinking_tokens")
137        .and_then(|v| v.as_f64())
138        .unwrap_or(0.0);
139    // NaN-guard every token count. Malformed trace data must not produce
140    // phantom costs that poison the diff report (NaN propagates through
141    // the whole bootstrap and would silently land on the cost axis).
142    if !(input.is_finite()
143        && output.is_finite()
144        && cached_input.is_finite()
145        && cached_write_5m.is_finite()
146        && cached_write_1h.is_finite()
147        && thinking.is_finite())
148    {
149        return Some(0.0);
150    }
151    let Some(p) = lookup_with_snapshot_fallback(pricing, model) else {
152        return Some(0.0);
153    };
154    let cached_rate = if p.cached_input > 0.0 {
155        p.cached_input
156    } else {
157        p.input
158    };
159    let reasoning_rate = if p.reasoning > 0.0 {
160        p.reasoning
161    } else {
162        p.output
163    };
164    // Cache-write rates fall back to input rate if not set.
165    let write_5m_rate = if p.cached_write_5m > 0.0 {
166        p.cached_write_5m
167    } else {
168        p.input
169    };
170    let write_1h_rate = if p.cached_write_1h > 0.0 {
171        p.cached_write_1h
172    } else {
173        p.input
174    };
175    // input_tokens in the envelope is assumed to mean "uncached input".
176    // If the provider emits a combined count, callers should split by
177    // populating `cached_input_tokens` / `cached_write_5m_tokens` /
178    // `cached_write_1h_tokens` separately.
179    let mut cost = input * p.input
180        + cached_input * cached_rate
181        + cached_write_5m * write_5m_rate
182        + cached_write_1h * write_1h_rate
183        + output * p.output
184        + thinking * reasoning_rate;
185    // Batch API flag lives at the envelope-meta level, but to avoid
186    // coupling the cost axis to the Envelope type we also honor a
187    // `payload.meta_batch` convention. (Default: no discount.)
188    let batch = r
189        .payload
190        .get("batch")
191        .and_then(|v| v.as_bool())
192        .unwrap_or(false);
193    if batch && p.batch_discount > 0.0 {
194        cost *= p.batch_discount;
195    }
196    Some(cost)
197}
198
199/// True iff both records in the pair reference models that appear in
200/// the pricing table. Unknown models make `cost_of` fall back to 0.0
201/// (so the bootstrap stays numerically clean), but for the purpose of
202/// deciding whether the axis has real signal we need to distinguish
203/// "priced at zero" from "model not in table."
204fn pair_is_priced(br: &Record, cr: &Record, pricing: &Pricing) -> bool {
205    fn model_in_table(r: &Record, pricing: &Pricing) -> bool {
206        r.payload
207            .get("model")
208            .and_then(|m| m.as_str())
209            .is_some_and(|m| lookup_with_snapshot_fallback(pricing, m).is_some())
210    }
211    model_in_table(br, pricing) && model_in_table(cr, pricing)
212}
213
214/// Compute the cost axis.
215///
216/// Silent zeros are a footgun: if no pricing table is supplied, or the
217/// table has no entries for the traced models, every pair prices at
218/// `0.0` and the naive implementation reports `delta=0,
219/// severity=None` — indistinguishable from "both sides are genuinely
220/// free." To prevent that, we track how many pairs had both models
221/// present in the pricing table. When that count is zero (or below
222/// half of `pairs.len()`), we attach [`Flag::NoPricing`] so downstream
223/// renderers and reviewers see a caveat instead of a spurious clean
224/// bill of health.
225pub fn compute(pairs: &[(&Record, &Record)], pricing: &Pricing, seed: Option<u64>) -> AxisStat {
226    let mut b = Vec::with_capacity(pairs.len());
227    let mut c = Vec::with_capacity(pairs.len());
228    let mut priced_pairs = 0usize;
229    for (br, cr) in pairs {
230        if let (Some(bv), Some(cv)) = (cost_of(br, pricing), cost_of(cr, pricing)) {
231            b.push(bv);
232            c.push(cv);
233            if pair_is_priced(br, cr, pricing) {
234                priced_pairs += 1;
235            }
236        }
237    }
238    if b.is_empty() {
239        let mut stat = AxisStat::empty(Axis::Cost);
240        // An empty pair list means "nothing to compare" (n=0 already
241        // surfaces that). A non-empty list with no costable entries
242        // means malformed records — still worth flagging.
243        if !pairs.is_empty() {
244            stat.flags.push(Flag::NoPricing);
245        }
246        return stat;
247    }
248    let bm = median(&b);
249    let cm = median(&c);
250    let delta = cm - bm;
251    let ci = paired_ci(&b, &c, |bs, cs| median(cs) - median(bs), 0, seed);
252    let mut stat = AxisStat::new_value(Axis::Cost, bm, cm, delta, ci.low, ci.high, b.len());
253    // Flag whenever fewer than half of the input pairs had both models
254    // priced — the median is unreliable and the user likely forgot
255    // pricing for a subset of models. `priced_pairs == 0` means no
256    // pricing at all; the flag covers both cases.
257    if priced_pairs * 2 < pairs.len() {
258        stat.flags.push(Flag::NoPricing);
259    }
260    stat
261}
262
263#[cfg(test)]
264mod tests {
265    use super::*;
266    use crate::agentlog::Kind;
267    use crate::diff::axes::Severity;
268    use serde_json::json;
269
270    fn response(model: &str, input: u64, output: u64) -> Record {
271        Record::new(
272            Kind::ChatResponse,
273            json!({
274                "model": model,
275                "content": [],
276                "stop_reason": "end_turn",
277                "latency_ms": 0,
278                "usage": {"input_tokens": input, "output_tokens": output, "thinking_tokens": 0},
279            }),
280            "2026-04-21T10:00:00Z",
281            None,
282        )
283    }
284
285    fn response_with_usage(model: &str, usage: serde_json::Value) -> Record {
286        Record::new(
287            Kind::ChatResponse,
288            json!({
289                "model": model,
290                "content": [],
291                "stop_reason": "end_turn",
292                "latency_ms": 0,
293                "usage": usage,
294            }),
295            "2026-04-21T10:00:00Z",
296            None,
297        )
298    }
299
300    #[test]
301    fn pricing_lookup_drives_cost() {
302        let mut pricing = Pricing::new();
303        pricing.insert("opus".to_string(), ModelPricing::simple(0.000015, 0.000075));
304        pricing.insert(
305            "haiku".to_string(),
306            ModelPricing::simple(0.0000008, 0.000004),
307        );
308        let baseline: Vec<Record> = (0..10).map(|_| response("opus", 1000, 500)).collect();
309        let candidate: Vec<Record> = (0..10).map(|_| response("haiku", 1000, 500)).collect();
310        let pairs: Vec<(&Record, &Record)> = baseline.iter().zip(candidate.iter()).collect();
311        let stat = compute(&pairs, &pricing, Some(1));
312        assert!(stat.delta < 0.0);
313        assert_eq!(stat.severity, Severity::Severe);
314    }
315
316    #[test]
317    fn unknown_model_costs_zero() {
318        let pricing = Pricing::new();
319        let r = response("mystery", 1000, 500);
320        let pairs = [(&r, &r)];
321        let stat = compute(&pairs, &pricing, Some(1));
322        assert_eq!(stat.baseline_median, 0.0);
323    }
324
325    #[test]
326    fn no_pricing_flag_when_table_is_empty_but_pairs_exist() {
327        // Pairs present, pricing absent → flag so reviewers know the
328        // delta=0 is "unknown" and not "equal cost".
329        let pricing = Pricing::new();
330        let r = response("mystery", 1000, 500);
331        let pairs = [(&r, &r), (&r, &r), (&r, &r)];
332        let stat = compute(&pairs, &pricing, Some(1));
333        assert!(stat.flags.contains(&Flag::NoPricing));
334        assert_eq!(stat.delta, 0.0);
335    }
336
337    #[test]
338    fn no_pricing_flag_when_most_models_unpriced() {
339        // Pricing provided for one of three traced models → too many
340        // pairs unpriced, flag still fires (median is unreliable).
341        let mut pricing = Pricing::new();
342        pricing.insert("opus".to_string(), ModelPricing::simple(0.000015, 0.000075));
343        let priced = response("opus", 1000, 500);
344        let unpriced1 = response("sonnet-unlisted", 1000, 500);
345        let unpriced2 = response("gpt-x-unlisted", 1000, 500);
346        let pairs = [
347            (&priced, &priced),
348            (&unpriced1, &unpriced1),
349            (&unpriced2, &unpriced2),
350        ];
351        let stat = compute(&pairs, &pricing, Some(1));
352        // Only 1 of 3 pairs priced; half-or-more unpriced → flag fires.
353        assert!(stat.flags.contains(&Flag::NoPricing));
354    }
355
356    #[test]
357    fn no_pricing_flag_absent_when_all_pairs_priced() {
358        let mut pricing = Pricing::new();
359        pricing.insert("opus".to_string(), ModelPricing::simple(0.000015, 0.000075));
360        let r = response("opus", 1000, 500);
361        let pairs = [(&r, &r), (&r, &r)];
362        let stat = compute(&pairs, &pricing, Some(1));
363        assert!(!stat.flags.contains(&Flag::NoPricing));
364    }
365
366    #[test]
367    fn no_pricing_flag_absent_when_pairs_empty() {
368        // Empty pairs is a different story (nothing to compare), and the
369        // NoPricing flag would be misleading there — n=0 already says
370        // "no data." Assert we don't spuriously flag that case.
371        let pricing = Pricing::new();
372        let pairs: Vec<(&Record, &Record)> = Vec::new();
373        let stat = compute(&pairs, &pricing, Some(1));
374        assert!(!stat.flags.contains(&Flag::NoPricing));
375        assert_eq!(stat.n, 0);
376    }
377
378    #[test]
379    fn cached_input_tokens_billed_at_cheaper_rate() {
380        // 1000 uncached input @ $15/Mtok + 1000 cached @ $1.50/Mtok + 500 out @ $75/Mtok
381        let mut pricing = Pricing::new();
382        pricing.insert(
383            "opus".to_string(),
384            ModelPricing {
385                input: 0.000015,
386                output: 0.000075,
387                cached_input: 0.0000015, // 10% of input
388                cached_write_5m: 0.0,
389                cached_write_1h: 0.0,
390                reasoning: 0.0,
391                batch_discount: 0.0,
392            },
393        );
394        let r = response_with_usage(
395            "opus",
396            json!({
397                "input_tokens": 1000,
398                "output_tokens": 500,
399                "thinking_tokens": 0,
400                "cached_input_tokens": 1000,
401            }),
402        );
403        let pairs = [(&r, &r)];
404        let stat = compute(&pairs, &pricing, Some(1));
405        // 1000*0.000015 + 1000*0.0000015 + 500*0.000075 = 0.015 + 0.0015 + 0.0375 = 0.054
406        assert!((stat.baseline_median - 0.054).abs() < 1e-9);
407    }
408
409    #[test]
410    fn reasoning_tokens_billed_at_reasoning_rate() {
411        let mut pricing = Pricing::new();
412        pricing.insert(
413            "gpt-5".to_string(),
414            ModelPricing {
415                input: 0.000010,
416                output: 0.000040,
417                cached_input: 0.0,
418                cached_write_5m: 0.0,
419                cached_write_1h: 0.0,
420                reasoning: 0.000060, // reasoning costs more than output
421                batch_discount: 0.0,
422            },
423        );
424        let r = response_with_usage(
425            "gpt-5",
426            json!({
427                "input_tokens": 100,
428                "output_tokens": 100,
429                "thinking_tokens": 500,
430            }),
431        );
432        let pairs = [(&r, &r)];
433        let stat = compute(&pairs, &pricing, Some(1));
434        // 100*10e-6 + 100*40e-6 + 500*60e-6 = 0.001 + 0.004 + 0.030 = 0.035
435        assert!((stat.baseline_median - 0.035).abs() < 1e-6);
436    }
437
438    #[test]
439    fn anthropic_cache_write_tiers_are_billed_separately() {
440        // Opus: input=$15/Mtok, 5m-write=$18.75/Mtok (1.25x), 1h-write=$30/Mtok (2x)
441        let mut pricing = Pricing::new();
442        pricing.insert(
443            "opus".to_string(),
444            ModelPricing {
445                input: 0.000015,
446                output: 0.000075,
447                cached_input: 0.0000015,
448                cached_write_5m: 0.00001875,
449                cached_write_1h: 0.00003,
450                reasoning: 0.0,
451                batch_discount: 0.0,
452            },
453        );
454        let r = Record::new(
455            Kind::ChatResponse,
456            json!({
457                "model": "opus",
458                "content": [],
459                "stop_reason": "end_turn",
460                "latency_ms": 0,
461                "usage": {
462                    "input_tokens": 1000,
463                    "output_tokens": 200,
464                    "thinking_tokens": 0,
465                    "cached_input_tokens": 500,
466                    "cached_write_5m_tokens": 200,
467                    "cached_write_1h_tokens": 100,
468                },
469            }),
470            "2026-04-21T10:00:00Z",
471            None,
472        );
473        let pairs = [(&r, &r)];
474        let stat = compute(&pairs, &pricing, Some(1));
475        // True cost:
476        //   uncached input: 1000 * 15e-6    = 0.015
477        //   cached read:    500  * 1.5e-6   = 0.00075
478        //   5m write:       200  * 18.75e-6 = 0.00375
479        //   1h write:       100  * 30e-6    = 0.003
480        //   output:         200  * 75e-6    = 0.015
481        // Total: 0.0375
482        assert!(
483            (stat.baseline_median - 0.0375).abs() < 1e-6,
484            "got {}",
485            stat.baseline_median
486        );
487    }
488
489    #[test]
490    fn nan_usage_values_produce_zero_cost_not_phantom_inf() {
491        let mut pricing = Pricing::new();
492        pricing.insert("m".to_string(), ModelPricing::simple(0.001, 0.002));
493        let r = Record::new(
494            Kind::ChatResponse,
495            json!({
496                "model": "m",
497                "content": [],
498                "stop_reason": "end_turn",
499                "latency_ms": 0,
500                "usage": {
501                    "input_tokens": 100.0,
502                    "output_tokens": 100.0,
503                    "thinking_tokens": 0,
504                    "cached_input_tokens": f64::NAN,
505                },
506            }),
507            "2026-04-21T10:00:00Z",
508            None,
509        );
510        let pairs = [(&r, &r)];
511        let stat = compute(&pairs, &pricing, Some(1));
512        // Cost must be finite (0.0 by our NaN-guard policy), not NaN.
513        assert!(stat.baseline_median.is_finite());
514        assert_eq!(stat.severity, Severity::None);
515    }
516
517    #[test]
518    fn batch_flag_applies_discount() {
519        let mut pricing = Pricing::new();
520        pricing.insert(
521            "opus".to_string(),
522            ModelPricing {
523                input: 0.000015,
524                output: 0.000075,
525                cached_input: 0.0,
526                cached_write_5m: 0.0,
527                cached_write_1h: 0.0,
528                reasoning: 0.0,
529                batch_discount: 0.5, // 50% off for batch API
530            },
531        );
532        let batched = Record::new(
533            Kind::ChatResponse,
534            json!({
535                "model": "opus",
536                "content": [],
537                "stop_reason": "end_turn",
538                "latency_ms": 0,
539                "batch": true,
540                "usage": {"input_tokens": 1000, "output_tokens": 500, "thinking_tokens": 0},
541            }),
542            "2026-04-21T10:00:00Z",
543            None,
544        );
545        let non_batched = response("opus", 1000, 500);
546        let pairs_batched = [(&batched, &batched)];
547        let pairs_normal = [(&non_batched, &non_batched)];
548        let stat_b = compute(&pairs_batched, &pricing, Some(1));
549        let stat_n = compute(&pairs_normal, &pricing, Some(1));
550        assert!((stat_b.baseline_median - stat_n.baseline_median * 0.5).abs() < 1e-9);
551    }
552
553    #[test]
554    fn snapshot_tail_strips_iso_dates() {
555        assert_eq!(strip_snapshot_tail("gpt-5-2025-08-07"), Some("gpt-5"));
556        assert_eq!(
557            strip_snapshot_tail("gpt-4o-mini-2024-07-18"),
558            Some("gpt-4o-mini"),
559        );
560        // Anthropic-style packed YYYYMMDD.
561        assert_eq!(
562            strip_snapshot_tail("claude-opus-4-7-20250219"),
563            Some("claude-opus-4-7"),
564        );
565        // No tail.
566        assert_eq!(strip_snapshot_tail("gpt-5"), None);
567        assert_eq!(strip_snapshot_tail("gpt-4o-mini"), None);
568        // False positives not stripped (model name happens to end in digits).
569        assert_eq!(strip_snapshot_tail("o1"), None);
570    }
571
572    #[test]
573    fn cost_resolves_dated_snapshot_to_bare_alias() {
574        // Real-world bug: chat_response carries `gpt-5-2025-08-07` but
575        // pricing.json is keyed by `gpt-5`. Without the fallback the cost
576        // axis silently flagged `no_pricing` and reported zero delta.
577        let mut pricing = Pricing::new();
578        pricing.insert(
579            "gpt-5".to_string(),
580            ModelPricing {
581                input: 0.000010,
582                output: 0.000040,
583                cached_input: 0.0,
584                cached_write_5m: 0.0,
585                cached_write_1h: 0.0,
586                reasoning: 0.0,
587                batch_discount: 0.0,
588            },
589        );
590        let r = response("gpt-5-2025-08-07", 100, 50);
591        let cost = cost_of(&r, &pricing).unwrap();
592        // 100 * 1e-5 + 50 * 4e-5 = 0.001 + 0.002 = 0.003
593        assert!((cost - 0.003).abs() < 1e-9, "got {}", cost);
594        // pair_is_priced must also see the snapshot as priced.
595        let pairs = [(&r, &r)];
596        let stat = compute(&pairs, &pricing, Some(42));
597        assert!(
598            !stat.flags.contains(&Flag::NoPricing),
599            "pair_is_priced should accept dated snapshots"
600        );
601    }
602}