Skip to main content

tt_plan_core/
replay.rs

1//! The deterministic replay loop. No async, no I/O, no clock reads — pure
2//! function of the input. Determinism is the contract enforced by
3//! `tests/replay.rs`.
4
5use std::collections::HashMap;
6
7use uuid::Uuid;
8
9use crate::{
10    bootstrap, cost,
11    error::PlanError,
12    routing,
13    types::{
14        Aggregates, ConfidenceIntervals, PerRouteBreakdown, PlanInput, PlanResult, ProposedRoute,
15        RequestLog,
16    },
17};
18
19/// The Plan replay entry point.
20///
21/// Pure function: same `(historical_rows, proposed_config, seed)` →
22/// bit-identical [`PlanResult`]. Determinism is the contract; the CI
23/// snapshot test in `tests/replay.rs` verifies it.
24///
25/// # Errors
26///
27/// Returns [`PlanError::InvalidWindow`] when `window_end <= window_start`
28/// and [`PlanError::ZeroBootstrapIterations`] when the caller passes
29/// `bootstrap_iterations = 0` (every CI would be `(0, 0)`, almost
30/// certainly a mistake).
31pub fn replay(input: PlanInput) -> Result<PlanResult, PlanError> {
32    validate(&input)?;
33
34    // Sort routes by priority descending — first match wins. Tie-break on
35    // the route's `id` (ascending) so equal-priority routes have a stable,
36    // config-intrinsic order independent of the caller's input array order.
37    // Without this, two logically-identical configs that differ only in the
38    // ordering of two equal-priority matching routes could resolve to
39    // different winners and thus different projected savings — violating the
40    // replay's "same config → bit-identical result" determinism contract.
41    let mut routes = input.proposed_routes.clone();
42    routes.sort_by(|a, b| b.priority.cmp(&a.priority).then_with(|| a.id.cmp(&b.id)));
43
44    // Walk requests in deterministic order (by id).
45    let mut requests = input.requests.clone();
46    requests.sort_by_key(|r| r.id);
47
48    // Project L1 cache hits (exact-match) under the proposed TTL. A projected
49    // hit serves the response for free, so its projected cost is zeroed in the
50    // cost loop — otherwise a cache-adding diff would show $0 savings.
51    // (L2 semantic-hit cost-zeroing is a follow-up: it needs a per-request hit
52    // set at the chosen threshold, and L2 isn't wired in the live gateway yet.)
53    let cache_hit_ids = crate::cache_projection::project_l1_hit_ids(&requests, &input.config);
54    let projection = project_requests(&requests, &routes, &input.pricing, &cache_hit_ids);
55
56    let mut aggregates = aggregate(&projection);
57
58    // Apply L1 cache projection from the proposed PlanConfig (if any TTL is set).
59    // This overrides the historical cache_hit_rate echo with the projected rate
60    // under the proposed TTL window — the actual answer to "would this cache
61    // config have helped?"
62    if input.config.l1_ttl_seconds.is_some() {
63        let proj = crate::cache_projection::project_l1_hits(&requests, &input.config);
64        aggregates.cache_hit_rate_projected = proj.projected_l1_hit_rate;
65    }
66
67    // Apply L2 (semantic) projection when any request in the window carries
68    // an embedding. The projection module short-circuits on missing
69    // configuration (TTL `None`, empty sweep, no embeddings present) — the
70    // outer guard here keeps the snapshot test stable on embedding-less
71    // fixtures by ensuring the new fields stay at their `Default::default()`.
72    if !requests.is_empty() && requests.iter().any(|r| r.embedding.is_some()) {
73        let l2 = crate::l2_projection::project_l2_hits(&requests, &input.config);
74        aggregates.l2_projections = l2.per_threshold;
75        aggregates.l2_poisoning_candidates = l2.poisoning_candidates;
76    }
77
78    let confidence_intervals = compute_cis(&projection, input.seed, input.bootstrap_iterations);
79    let per_route_breakdown = build_per_route(projection.per_route);
80
81    // Carry the proposed routes through to the result so the apply path can
82    // persist them. We move the *original* (unsorted) input vec rather than
83    // the priority-sorted `routes` clone above — apply re-sorts at write time
84    // and we want to preserve the caller's authored ordering for round-trip
85    // fidelity. This is a partial move out of `input`; the remaining fields
86    // read below (`plan_id`, `org_id`, the window bounds) are all `Copy`.
87    let proposed_routes = input.proposed_routes;
88
89    let mut caveats = build_caveats(
90        requests.len(),
91        aggregates.requests_unprice_able,
92        projection.latency_unprojected,
93    );
94    caveats.extend(wide_ci_caveats(&aggregates, &confidence_intervals));
95
96    Ok(PlanResult {
97        plan_id: input.plan_id,
98        org_id: input.org_id,
99        window_start: input.window_start,
100        window_end: input.window_end,
101        sample_size: requests.len() as u32,
102        aggregates,
103        confidence_intervals,
104        per_route_breakdown,
105        caveats,
106        // Tier 3 quality scoring is opt-in and dispatched via
107        // `replay_with_quality`; bare `replay()` returns `None` here so the
108        // existing JSON snapshot stays byte-identical.
109        quality: None,
110        proposed_routes,
111    })
112}
113
114/// Convenience helper that runs [`replay`] then attaches a Tier 3 quality
115/// score via [`crate::quality::score_quality`]. The CLI / hosted worker
116/// calls this when the org has body logging enabled and supplied a judge.
117///
118/// `proposed_response_for` is the caller-owned hook that re-runs the
119/// proposed model for a given request id; see
120/// [`crate::quality::score_quality`] for the contract.
121///
122/// # Errors
123///
124/// - Any error [`replay`] would return (validation failures).
125/// - Any error [`crate::quality::score_quality`] would return — these are
126///   *not* converted to [`PlanError`] because they're distinct conditions
127///   (opt-in gate, budget, judge failure) the caller surfaces differently.
128pub async fn replay_with_quality<F>(
129    input: PlanInput,
130    judge: &dyn crate::quality::JudgeProvider,
131    quality_config: &crate::quality::QualityConfig,
132    proposed_response_for: F,
133) -> Result<PlanResult, ReplayWithQualityError>
134where
135    F: Fn(&Uuid) -> Option<String>,
136{
137    // Clone the requests slice up front so we can hand it to both the
138    // sync replay and the async quality scorer without juggling ownership.
139    let requests = input.requests.clone();
140    let mut result = replay(input).map_err(ReplayWithQualityError::Replay)?;
141    let quality =
142        crate::quality::score_quality(&requests, quality_config, judge, proposed_response_for)
143            .await
144            .map_err(ReplayWithQualityError::Quality)?;
145    result.quality = Some(quality);
146    Ok(result)
147}
148
149/// Combined error envelope for [`replay_with_quality`]. Variants stay
150/// distinct so callers can render appropriate UX (`PlanError` is a
151/// validation/config failure; `QualityError` is a runtime / opt-in issue).
152#[derive(Debug, thiserror::Error)]
153pub enum ReplayWithQualityError {
154    /// The deterministic replay stage failed (invalid window, etc.).
155    #[error("replay: {0}")]
156    Replay(#[from] crate::error::PlanError),
157    /// The Tier 3 quality scoring stage failed (no opt-in, over budget,
158    /// judge error, …).
159    #[error("quality: {0}")]
160    Quality(#[from] crate::quality::QualityError),
161}
162
163fn validate(input: &PlanInput) -> Result<(), PlanError> {
164    if input.window_end <= input.window_start {
165        return Err(PlanError::InvalidWindow {
166            start: input.window_start.to_rfc3339(),
167            end: input.window_end.to_rfc3339(),
168        });
169    }
170    if input.bootstrap_iterations == 0 {
171        return Err(PlanError::ZeroBootstrapIterations);
172    }
173    Ok(())
174}
175
176/// Per-route accumulator built during the projection pass.
177struct PerRouteBucket {
178    route_id: Uuid,
179    route_name: String,
180    matched: u32,
181    baseline_cost_usd: f64,
182    projected_cost_usd: f64,
183}
184
185/// All the per-request vectors plus the per-route buckets the aggregation
186/// pass needs. Kept as a struct (not a tuple) so the field meanings are
187/// obvious at call sites.
188struct Projection {
189    per_request_baseline: Vec<f64>,
190    per_request_projected: Vec<f64>,
191    per_request_latency: Vec<f64>,
192    per_request_cache_hit: Vec<f64>,
193    per_route: HashMap<Uuid, PerRouteBucket>,
194    requests_rerouted: u32,
195    requests_unchanged: u32,
196    requests_unprice_able: u32,
197    /// Rerouted requests whose target model had no latency history in the
198    /// window — their latency is shown unchanged (can't be projected).
199    latency_unprojected: u32,
200}
201
202fn project_requests(
203    requests: &[RequestLog],
204    routes: &[ProposedRoute],
205    pricing: &crate::types::PricingTable,
206    cache_hit_ids: &std::collections::HashSet<Uuid>,
207) -> Projection {
208    let cap = requests.len();
209    let mut per_request_baseline = Vec::with_capacity(cap);
210    let mut per_request_projected = Vec::with_capacity(cap);
211    let mut per_request_latency = Vec::with_capacity(cap);
212    let mut per_request_cache_hit = Vec::with_capacity(cap);
213    let mut per_route: HashMap<Uuid, PerRouteBucket> = HashMap::new();
214    let mut requests_rerouted: u32 = 0;
215    let mut requests_unchanged: u32 = 0;
216    let mut requests_unprice_able: u32 = 0;
217    let mut latency_unprojected: u32 = 0;
218
219    // Median latency per model across the window — used to project a rerouted
220    // request's latency from its TARGET model's history rather than echoing the
221    // original (baseline) model's latency.
222    let model_medians = model_median_latencies(requests);
223
224    for req in requests {
225        per_request_baseline.push(req.baseline_cost_usd);
226        per_request_cache_hit.push(if req.cached { 1.0 } else { 0.0 });
227
228        // A projected cache hit serves the response for free regardless of
229        // routing, so its projected cost is 0.
230        let is_cache_hit = cache_hit_ids.contains(&req.id);
231
232        let matched = routing::match_route(req, routes);
233        match matched {
234            Some(route) => {
235                let target_key = crate::types::pricing_key(&req.provider, &route.then.target_model);
236                if let Some(p) = pricing.get(&target_key) {
237                    let projected = cost::project_cost(req, &route.then.target_model, p);
238                    let projected_cost = if is_cache_hit {
239                        0.0
240                    } else {
241                        projected.cost_usd
242                    };
243                    per_request_projected.push(projected_cost);
244                    // Project latency from the target model's window history;
245                    // fall back to the request's own latency (and flag it) when
246                    // the target model has no history to project from.
247                    match model_medians.get(route.then.target_model.as_str()) {
248                        Some(&med) => per_request_latency.push(med),
249                        None => {
250                            per_request_latency.push(f64::from(req.latency_ms));
251                            latency_unprojected += 1;
252                        }
253                    }
254                    let bucket = per_route.entry(route.id).or_insert_with(|| PerRouteBucket {
255                        route_id: route.id,
256                        route_name: route.name.clone(),
257                        matched: 0,
258                        baseline_cost_usd: 0.0,
259                        projected_cost_usd: 0.0,
260                    });
261                    bucket.matched += 1;
262                    bucket.baseline_cost_usd += req.baseline_cost_usd;
263                    bucket.projected_cost_usd += projected_cost;
264                    requests_rerouted += 1;
265                } else {
266                    // No pricing for the target model — count as unchanged.
267                    // Conservative invariant: never fabricate savings.
268                    per_request_projected.push(if is_cache_hit { 0.0 } else { req.cost_usd });
269                    per_request_latency.push(f64::from(req.latency_ms));
270                    requests_unprice_able += 1;
271                }
272            }
273            None => {
274                per_request_projected.push(if is_cache_hit { 0.0 } else { req.cost_usd });
275                per_request_latency.push(f64::from(req.latency_ms));
276                requests_unchanged += 1;
277            }
278        }
279    }
280
281    Projection {
282        per_request_baseline,
283        per_request_projected,
284        per_request_latency,
285        per_request_cache_hit,
286        per_route,
287        requests_rerouted,
288        requests_unchanged,
289        requests_unprice_able,
290        latency_unprojected,
291    }
292}
293
294/// Median latency (ms) per model across the window. Deterministic: sorts the
295/// per-model latencies and takes the upper-middle element. Empty input → empty
296/// map.
297fn model_median_latencies(requests: &[RequestLog]) -> HashMap<&str, f64> {
298    let mut by_model: HashMap<&str, Vec<u32>> = HashMap::new();
299    for r in requests {
300        by_model
301            .entry(r.model.as_str())
302            .or_default()
303            .push(r.latency_ms);
304    }
305    by_model
306        .into_iter()
307        .map(|(model, mut lat)| {
308            lat.sort_unstable();
309            (model, f64::from(lat[lat.len() / 2]))
310        })
311        .collect()
312}
313
314fn aggregate(p: &Projection) -> Aggregates {
315    let total_baseline: f64 = p.per_request_baseline.iter().sum();
316    let total_projected: f64 = p.per_request_projected.iter().sum();
317    let projected_savings = (total_baseline - total_projected).max(0.0);
318    let projected_savings_pct = if total_baseline > 0.0 {
319        projected_savings / total_baseline * 100.0
320    } else {
321        0.0
322    };
323    let cache_hit_rate = if p.per_request_cache_hit.is_empty() {
324        0.0
325    } else {
326        p.per_request_cache_hit.iter().sum::<f64>() / p.per_request_cache_hit.len() as f64
327    };
328    let p50_latency = percentile(&p.per_request_latency, 0.50);
329    let p95_latency = percentile(&p.per_request_latency, 0.95);
330
331    Aggregates {
332        total_baseline_cost_usd: total_baseline,
333        total_projected_cost_usd: total_projected,
334        projected_savings_usd: projected_savings,
335        projected_savings_pct,
336        cache_hit_rate_projected: cache_hit_rate,
337        p50_latency_ms_projected: p50_latency,
338        p95_latency_ms_projected: p95_latency,
339        requests_rerouted: p.requests_rerouted,
340        requests_unchanged: p.requests_unchanged,
341        requests_unprice_able: p.requests_unprice_able,
342        // L2 sweep + poisoning are populated downstream by `replay` when the
343        // window carries embeddings; default to empty/zero here.
344        l2_projections: Vec::new(),
345        l2_poisoning_candidates: 0,
346    }
347}
348
349fn compute_cis(p: &Projection, seed: u64, iterations: u32) -> ConfidenceIntervals {
350    // Savings (USD): bootstrap the per-request savings delta, scale the
351    // resampled MEAN back to a TOTAL by multiplying by the original n.
352    // (Each resample has the same n as the original, so mean × n = total.)
353    let n = p.per_request_baseline.len() as f64;
354    let savings_per_req: Vec<f64> = p
355        .per_request_baseline
356        .iter()
357        .zip(p.per_request_projected.iter())
358        .map(|(b, pr)| (b - pr).max(0.0))
359        .collect();
360    let (sv_lo_mean, sv_hi_mean) =
361        bootstrap::bootstrap_ci(&savings_per_req, seed, iterations, (0.025, 0.975));
362    let savings_usd_95 = (sv_lo_mean * n, sv_hi_mean * n);
363
364    // Savings pct: must bootstrap baseline + projected jointly because
365    // pct = (sum_b - sum_p) / sum_b is a ratio of two sums.
366    let savings_pct_95 = bootstrap_pct_savings_ci(
367        &p.per_request_baseline,
368        &p.per_request_projected,
369        seed.wrapping_add(1),
370        iterations,
371    );
372
373    // Cache hit rate: bootstrap the 0/1 hit vector — the mean of bools is
374    // exactly the hit rate.
375    let cache_hit_rate_95 = bootstrap::bootstrap_ci(
376        &p.per_request_cache_hit,
377        seed.wrapping_add(2),
378        iterations,
379        (0.025, 0.975),
380    );
381
382    // Latency percentile CIs: percentile-of-percentiles bootstrap — each
383    // resample computes its own p50/p95, then we take the 2.5/97.5 of those.
384    let p50_latency_ms_95 = bootstrap_percentile_ci(
385        &p.per_request_latency,
386        0.50,
387        seed.wrapping_add(3),
388        iterations,
389    );
390    let p95_latency_ms_95 = bootstrap_percentile_ci(
391        &p.per_request_latency,
392        0.95,
393        seed.wrapping_add(4),
394        iterations,
395    );
396
397    ConfidenceIntervals {
398        savings_usd_95,
399        savings_pct_95,
400        cache_hit_rate_95,
401        p50_latency_ms_95,
402        p95_latency_ms_95,
403    }
404}
405
406/// Bootstrap a CI on a quantile of `values`. Each iteration: resample with
407/// replacement, compute the requested percentile on the resample, collect.
408/// Return the 2.5/97.5 percentiles of those resampled-percentile values.
409fn bootstrap_percentile_ci(values: &[f64], q: f64, seed: u64, iterations: u32) -> (f64, f64) {
410    use rand::{Rng, SeedableRng};
411    use rand_chacha::ChaCha8Rng;
412    if values.is_empty() || iterations == 0 {
413        return (0.0, 0.0);
414    }
415    let n = values.len();
416    let mut rng = ChaCha8Rng::seed_from_u64(seed);
417    let mut samples: Vec<f64> = Vec::with_capacity(iterations as usize);
418    let mut buf: Vec<f64> = Vec::with_capacity(n);
419    for _ in 0..iterations {
420        buf.clear();
421        for _ in 0..n {
422            buf.push(values[rng.gen_range(0..n)]);
423        }
424        samples.push(percentile(&buf, q));
425    }
426    samples.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
427    let lo_idx = (0.025 * iterations as f64) as usize;
428    let hi_idx = ((0.975 * iterations as f64) as usize).min(iterations as usize - 1);
429    (samples[lo_idx], samples[hi_idx])
430}
431
432fn build_per_route(buckets: HashMap<Uuid, PerRouteBucket>) -> Vec<PerRouteBreakdown> {
433    let mut rows: Vec<PerRouteBreakdown> = buckets
434        .into_values()
435        .map(|b| PerRouteBreakdown {
436            route_id: b.route_id,
437            route_name: b.route_name,
438            matched: b.matched,
439            baseline_cost_usd: b.baseline_cost_usd,
440            projected_cost_usd: b.projected_cost_usd,
441            savings_usd: (b.baseline_cost_usd - b.projected_cost_usd).max(0.0),
442        })
443        .collect();
444    // Sort by route_id for determinism — savings-desc is unstable on ties
445    // and would break the snapshot test on slight float drift.
446    rows.sort_by_key(|r| r.route_id);
447    rows
448}
449
/// Human-readable caveats derived from the replay's counters.
///
/// Caveats are emitted in a fixed order, each only when its condition holds:
/// 1. the window has fewer than 1000 requests (small sample);
/// 2. some requests were routed to a target model with no pricing entry;
/// 3. some rerouted requests' target model had no latency history.
fn build_caveats(
    sample_size: usize,
    requests_unprice_able: u32,
    latency_unprojected: u32,
) -> Vec<String> {
    const SMALL_SAMPLE_THRESHOLD: usize = 1000;

    let small_sample = (sample_size < SMALL_SAMPLE_THRESHOLD).then(|| {
        format!("Small sample size ({sample_size} requests) — confidence intervals are wide.")
    });
    let unpriced = (requests_unprice_able > 0).then(|| {
        format!(
            "{requests_unprice_able} request(s) routed to a target model with no pricing entry — counted as unchanged."
        )
    });
    let latency = (latency_unprojected > 0).then(|| {
        format!(
            "{latency_unprojected} rerouted request(s) had no latency history for the target model — their latency is shown unchanged, not projected."
        )
    });

    small_sample.into_iter().chain(unpriced).chain(latency).collect()
}
473
474/// Relative CI width > 30% means the projection is too uncertain to act on.
475/// Called from `replay()` after CIs are computed so the caveat reflects the
476/// actual bootstrap result rather than just the sample size.
477pub(crate) fn wide_ci_caveats(aggregates: &Aggregates, cis: &ConfidenceIntervals) -> Vec<String> {
478    let mut out = Vec::new();
479    let rel_width = |lo: f64, hi: f64, center: f64| -> Option<f64> {
480        if center.abs() < f64::EPSILON {
481            return None;
482        }
483        Some((hi - lo).abs() / center.abs())
484    };
485    if let Some(w) = rel_width(
486        cis.savings_usd_95.0,
487        cis.savings_usd_95.1,
488        aggregates.projected_savings_usd,
489    ) {
490        if w > 0.30 {
491            out.push(format!(
492                "Savings CI is wide: ±{:.0}% relative width. Treat the headline savings number as a rough estimate; consider scanning a larger window.",
493                w * 100.0
494            ));
495        }
496    }
497    if let Some(w) = rel_width(
498        cis.p50_latency_ms_95.0,
499        cis.p50_latency_ms_95.1,
500        aggregates.p50_latency_ms_projected,
501    ) {
502        if w > 0.30 {
503            out.push(format!(
504                "p50 latency CI is wide: ±{:.0}% relative width.",
505                w * 100.0
506            ));
507        }
508    }
509    out
510}
511
/// Rounded-index quantile of `values` (the input does not need to be sorted).
///
/// Sorts a copy ascending and returns the element at index
/// `round(q × (len − 1))`, clamped to the last element. A negative `q`
/// saturates to index 0. An empty slice yields 0.0.
fn percentile(values: &[f64], q: f64) -> f64 {
    let last = match values.len().checked_sub(1) {
        Some(last) => last,
        None => return 0.0,
    };
    let mut sorted = values.to_vec();
    sorted.sort_by(|x, y| x.partial_cmp(y).unwrap_or(std::cmp::Ordering::Equal));
    let rank = (q * last as f64).round() as usize;
    sorted[rank.min(last)]
}
521
522/// Bootstrap the percentage-savings CI by jointly resampling baseline and
523/// projected costs. Distinct from `bootstrap_ci` because we need the
524/// ratio of two sums, not the mean of a single sample.
525fn bootstrap_pct_savings_ci(
526    baseline: &[f64],
527    projected: &[f64],
528    seed: u64,
529    iterations: u32,
530) -> (f64, f64) {
531    use rand::Rng;
532    use rand::SeedableRng;
533    use rand_chacha::ChaCha8Rng;
534
535    let n = baseline.len();
536    if n == 0 || iterations == 0 || n != projected.len() {
537        return (0.0, 0.0);
538    }
539    let mut rng = ChaCha8Rng::seed_from_u64(seed);
540    let mut pct_samples: Vec<f64> = Vec::with_capacity(iterations as usize);
541    for _ in 0..iterations {
542        let mut b_sum = 0.0;
543        let mut p_sum = 0.0;
544        for _ in 0..n {
545            let idx = rng.gen_range(0..n);
546            b_sum += baseline[idx];
547            p_sum += projected[idx];
548        }
549        let pct = if b_sum > 0.0 {
550            (b_sum - p_sum) / b_sum * 100.0
551        } else {
552            0.0
553        };
554        pct_samples.push(pct.max(0.0));
555    }
556    pct_samples.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
557    let iter_f = iterations as f64;
558    let lo_idx = ((0.025 * iter_f) as usize).min(pct_samples.len() - 1);
559    let hi_idx = ((0.975 * iter_f) as usize).min(pct_samples.len() - 1);
560    (pct_samples[lo_idx], pct_samples[hi_idx])
561}