1use std::collections::HashMap;
6
7use uuid::Uuid;
8
9use crate::{
10 bootstrap, cost,
11 error::PlanError,
12 routing,
13 types::{
14 Aggregates, ConfidenceIntervals, PerRouteBreakdown, PlanInput, PlanResult, ProposedRoute,
15 RequestLog,
16 },
17};
18
19pub fn replay(input: PlanInput) -> Result<PlanResult, PlanError> {
32 validate(&input)?;
33
34 let mut routes = input.proposed_routes.clone();
42 routes.sort_by(|a, b| b.priority.cmp(&a.priority).then_with(|| a.id.cmp(&b.id)));
43
44 let mut requests = input.requests.clone();
46 requests.sort_by_key(|r| r.id);
47
48 let cache_hit_ids = crate::cache_projection::project_l1_hit_ids(&requests, &input.config);
54 let projection = project_requests(&requests, &routes, &input.pricing, &cache_hit_ids);
55
56 let mut aggregates = aggregate(&projection);
57
58 if input.config.l1_ttl_seconds.is_some() {
63 let proj = crate::cache_projection::project_l1_hits(&requests, &input.config);
64 aggregates.cache_hit_rate_projected = proj.projected_l1_hit_rate;
65 }
66
67 if !requests.is_empty() && requests.iter().any(|r| r.embedding.is_some()) {
73 let l2 = crate::l2_projection::project_l2_hits(&requests, &input.config);
74 aggregates.l2_projections = l2.per_threshold;
75 aggregates.l2_poisoning_candidates = l2.poisoning_candidates;
76 }
77
78 let confidence_intervals = compute_cis(&projection, input.seed, input.bootstrap_iterations);
79 let per_route_breakdown = build_per_route(projection.per_route);
80
81 let proposed_routes = input.proposed_routes;
88
89 let mut caveats = build_caveats(
90 requests.len(),
91 aggregates.requests_unprice_able,
92 projection.latency_unprojected,
93 );
94 caveats.extend(wide_ci_caveats(&aggregates, &confidence_intervals));
95
96 Ok(PlanResult {
97 plan_id: input.plan_id,
98 org_id: input.org_id,
99 window_start: input.window_start,
100 window_end: input.window_end,
101 sample_size: requests.len() as u32,
102 aggregates,
103 confidence_intervals,
104 per_route_breakdown,
105 caveats,
106 quality: None,
110 proposed_routes,
111 })
112}
113
114pub async fn replay_with_quality<F>(
129 input: PlanInput,
130 judge: &dyn crate::quality::JudgeProvider,
131 quality_config: &crate::quality::QualityConfig,
132 proposed_response_for: F,
133) -> Result<PlanResult, ReplayWithQualityError>
134where
135 F: Fn(&Uuid) -> Option<String>,
136{
137 let requests = input.requests.clone();
140 let mut result = replay(input).map_err(ReplayWithQualityError::Replay)?;
141 let quality =
142 crate::quality::score_quality(&requests, quality_config, judge, proposed_response_for)
143 .await
144 .map_err(ReplayWithQualityError::Quality)?;
145 result.quality = Some(quality);
146 Ok(result)
147}
148
/// Error returned by `replay_with_quality`, distinguishing which of its two
/// stages failed.
#[derive(Debug, thiserror::Error)]
pub enum ReplayWithQualityError {
    /// The replay stage failed; wraps the underlying plan error.
    #[error("replay: {0}")]
    Replay(#[from] crate::error::PlanError),
    /// The quality-scoring stage failed; wraps the underlying quality error.
    #[error("quality: {0}")]
    Quality(#[from] crate::quality::QualityError),
}
162
163fn validate(input: &PlanInput) -> Result<(), PlanError> {
164 if input.window_end <= input.window_start {
165 return Err(PlanError::InvalidWindow {
166 start: input.window_start.to_rfc3339(),
167 end: input.window_end.to_rfc3339(),
168 });
169 }
170 if input.bootstrap_iterations == 0 {
171 return Err(PlanError::ZeroBootstrapIterations);
172 }
173 Ok(())
174}
175
/// Running totals for one proposed route, accumulated while projecting requests.
struct PerRouteBucket {
    /// Id of the route this bucket belongs to.
    route_id: Uuid,
    /// Name of the route, copied from the route when the bucket is created.
    route_name: String,
    /// Number of requests rerouted through this route.
    matched: u32,
    /// Sum of `baseline_cost_usd` over the rerouted requests.
    baseline_cost_usd: f64,
    /// Sum of the projected cost over the rerouted requests.
    projected_cost_usd: f64,
}
184
/// Intermediate result of `project_requests`.
///
/// The four `per_request_*` vectors each receive exactly one entry per input
/// request, in iteration order, so equal indices describe the same request.
struct Projection {
    /// `baseline_cost_usd` of each request.
    per_request_baseline: Vec<f64>,
    /// Projected cost of each request: `0.0` for projected cache hits, the
    /// target-model cost for rerouted requests, otherwise the logged `cost_usd`.
    per_request_projected: Vec<f64>,
    /// Projected latency of each request: the target model's median latency when
    /// rerouted and known, otherwise the request's own logged latency.
    per_request_latency: Vec<f64>,
    /// `1.0` when the request's `cached` flag is set, `0.0` otherwise.
    per_request_cache_hit: Vec<f64>,
    /// Per-route running totals, keyed by route id.
    per_route: HashMap<Uuid, PerRouteBucket>,
    /// Requests that matched a route whose target model has a pricing entry.
    requests_rerouted: u32,
    /// Requests that matched no route.
    requests_unchanged: u32,
    /// Requests that matched a route whose target model has no pricing entry.
    requests_unprice_able: u32,
    /// Rerouted requests whose target model had no latency samples in the log.
    latency_unprojected: u32,
}
201
202fn project_requests(
203 requests: &[RequestLog],
204 routes: &[ProposedRoute],
205 pricing: &crate::types::PricingTable,
206 cache_hit_ids: &std::collections::HashSet<Uuid>,
207) -> Projection {
208 let cap = requests.len();
209 let mut per_request_baseline = Vec::with_capacity(cap);
210 let mut per_request_projected = Vec::with_capacity(cap);
211 let mut per_request_latency = Vec::with_capacity(cap);
212 let mut per_request_cache_hit = Vec::with_capacity(cap);
213 let mut per_route: HashMap<Uuid, PerRouteBucket> = HashMap::new();
214 let mut requests_rerouted: u32 = 0;
215 let mut requests_unchanged: u32 = 0;
216 let mut requests_unprice_able: u32 = 0;
217 let mut latency_unprojected: u32 = 0;
218
219 let model_medians = model_median_latencies(requests);
223
224 for req in requests {
225 per_request_baseline.push(req.baseline_cost_usd);
226 per_request_cache_hit.push(if req.cached { 1.0 } else { 0.0 });
227
228 let is_cache_hit = cache_hit_ids.contains(&req.id);
231
232 let matched = routing::match_route(req, routes);
233 match matched {
234 Some(route) => {
235 let target_key = crate::types::pricing_key(&req.provider, &route.then.target_model);
236 if let Some(p) = pricing.get(&target_key) {
237 let projected = cost::project_cost(req, &route.then.target_model, p);
238 let projected_cost = if is_cache_hit {
239 0.0
240 } else {
241 projected.cost_usd
242 };
243 per_request_projected.push(projected_cost);
244 match model_medians.get(route.then.target_model.as_str()) {
248 Some(&med) => per_request_latency.push(med),
249 None => {
250 per_request_latency.push(f64::from(req.latency_ms));
251 latency_unprojected += 1;
252 }
253 }
254 let bucket = per_route.entry(route.id).or_insert_with(|| PerRouteBucket {
255 route_id: route.id,
256 route_name: route.name.clone(),
257 matched: 0,
258 baseline_cost_usd: 0.0,
259 projected_cost_usd: 0.0,
260 });
261 bucket.matched += 1;
262 bucket.baseline_cost_usd += req.baseline_cost_usd;
263 bucket.projected_cost_usd += projected_cost;
264 requests_rerouted += 1;
265 } else {
266 per_request_projected.push(if is_cache_hit { 0.0 } else { req.cost_usd });
269 per_request_latency.push(f64::from(req.latency_ms));
270 requests_unprice_able += 1;
271 }
272 }
273 None => {
274 per_request_projected.push(if is_cache_hit { 0.0 } else { req.cost_usd });
275 per_request_latency.push(f64::from(req.latency_ms));
276 requests_unchanged += 1;
277 }
278 }
279 }
280
281 Projection {
282 per_request_baseline,
283 per_request_projected,
284 per_request_latency,
285 per_request_cache_hit,
286 per_route,
287 requests_rerouted,
288 requests_unchanged,
289 requests_unprice_able,
290 latency_unprojected,
291 }
292}
293
294fn model_median_latencies(requests: &[RequestLog]) -> HashMap<&str, f64> {
298 let mut by_model: HashMap<&str, Vec<u32>> = HashMap::new();
299 for r in requests {
300 by_model
301 .entry(r.model.as_str())
302 .or_default()
303 .push(r.latency_ms);
304 }
305 by_model
306 .into_iter()
307 .map(|(model, mut lat)| {
308 lat.sort_unstable();
309 (model, f64::from(lat[lat.len() / 2]))
310 })
311 .collect()
312}
313
314fn aggregate(p: &Projection) -> Aggregates {
315 let total_baseline: f64 = p.per_request_baseline.iter().sum();
316 let total_projected: f64 = p.per_request_projected.iter().sum();
317 let projected_savings = (total_baseline - total_projected).max(0.0);
318 let projected_savings_pct = if total_baseline > 0.0 {
319 projected_savings / total_baseline * 100.0
320 } else {
321 0.0
322 };
323 let cache_hit_rate = if p.per_request_cache_hit.is_empty() {
324 0.0
325 } else {
326 p.per_request_cache_hit.iter().sum::<f64>() / p.per_request_cache_hit.len() as f64
327 };
328 let p50_latency = percentile(&p.per_request_latency, 0.50);
329 let p95_latency = percentile(&p.per_request_latency, 0.95);
330
331 Aggregates {
332 total_baseline_cost_usd: total_baseline,
333 total_projected_cost_usd: total_projected,
334 projected_savings_usd: projected_savings,
335 projected_savings_pct,
336 cache_hit_rate_projected: cache_hit_rate,
337 p50_latency_ms_projected: p50_latency,
338 p95_latency_ms_projected: p95_latency,
339 requests_rerouted: p.requests_rerouted,
340 requests_unchanged: p.requests_unchanged,
341 requests_unprice_able: p.requests_unprice_able,
342 l2_projections: Vec::new(),
345 l2_poisoning_candidates: 0,
346 }
347}
348
349fn compute_cis(p: &Projection, seed: u64, iterations: u32) -> ConfidenceIntervals {
350 let n = p.per_request_baseline.len() as f64;
354 let savings_per_req: Vec<f64> = p
355 .per_request_baseline
356 .iter()
357 .zip(p.per_request_projected.iter())
358 .map(|(b, pr)| (b - pr).max(0.0))
359 .collect();
360 let (sv_lo_mean, sv_hi_mean) =
361 bootstrap::bootstrap_ci(&savings_per_req, seed, iterations, (0.025, 0.975));
362 let savings_usd_95 = (sv_lo_mean * n, sv_hi_mean * n);
363
364 let savings_pct_95 = bootstrap_pct_savings_ci(
367 &p.per_request_baseline,
368 &p.per_request_projected,
369 seed.wrapping_add(1),
370 iterations,
371 );
372
373 let cache_hit_rate_95 = bootstrap::bootstrap_ci(
376 &p.per_request_cache_hit,
377 seed.wrapping_add(2),
378 iterations,
379 (0.025, 0.975),
380 );
381
382 let p50_latency_ms_95 = bootstrap_percentile_ci(
385 &p.per_request_latency,
386 0.50,
387 seed.wrapping_add(3),
388 iterations,
389 );
390 let p95_latency_ms_95 = bootstrap_percentile_ci(
391 &p.per_request_latency,
392 0.95,
393 seed.wrapping_add(4),
394 iterations,
395 );
396
397 ConfidenceIntervals {
398 savings_usd_95,
399 savings_pct_95,
400 cache_hit_rate_95,
401 p50_latency_ms_95,
402 p95_latency_ms_95,
403 }
404}
405
406fn bootstrap_percentile_ci(values: &[f64], q: f64, seed: u64, iterations: u32) -> (f64, f64) {
410 use rand::{Rng, SeedableRng};
411 use rand_chacha::ChaCha8Rng;
412 if values.is_empty() || iterations == 0 {
413 return (0.0, 0.0);
414 }
415 let n = values.len();
416 let mut rng = ChaCha8Rng::seed_from_u64(seed);
417 let mut samples: Vec<f64> = Vec::with_capacity(iterations as usize);
418 let mut buf: Vec<f64> = Vec::with_capacity(n);
419 for _ in 0..iterations {
420 buf.clear();
421 for _ in 0..n {
422 buf.push(values[rng.gen_range(0..n)]);
423 }
424 samples.push(percentile(&buf, q));
425 }
426 samples.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
427 let lo_idx = (0.025 * iterations as f64) as usize;
428 let hi_idx = ((0.975 * iterations as f64) as usize).min(iterations as usize - 1);
429 (samples[lo_idx], samples[hi_idx])
430}
431
432fn build_per_route(buckets: HashMap<Uuid, PerRouteBucket>) -> Vec<PerRouteBreakdown> {
433 let mut rows: Vec<PerRouteBreakdown> = buckets
434 .into_values()
435 .map(|b| PerRouteBreakdown {
436 route_id: b.route_id,
437 route_name: b.route_name,
438 matched: b.matched,
439 baseline_cost_usd: b.baseline_cost_usd,
440 projected_cost_usd: b.projected_cost_usd,
441 savings_usd: (b.baseline_cost_usd - b.projected_cost_usd).max(0.0),
442 })
443 .collect();
444 rows.sort_by_key(|r| r.route_id);
447 rows
448}
449
/// Builds the human-readable caveats that depend only on request counts.
///
/// Emits, in this order and only when applicable:
/// * a small-sample warning when `sample_size` is below 1000;
/// * a note when `requests_unprice_able` is non-zero;
/// * a note when `latency_unprojected` is non-zero.
fn build_caveats(
    sample_size: usize,
    requests_unprice_able: u32,
    latency_unprojected: u32,
) -> Vec<String> {
    let small_sample = (sample_size < 1000).then(|| {
        format!("Small sample size ({sample_size} requests) — confidence intervals are wide.")
    });
    let missing_pricing = (requests_unprice_able > 0).then(|| {
        format!(
            "{requests_unprice_able} request(s) routed to a target model with no pricing entry — counted as unchanged."
        )
    });
    let missing_latency = (latency_unprojected > 0).then(|| {
        format!(
            "{latency_unprojected} rerouted request(s) had no latency history for the target model — their latency is shown unchanged, not projected."
        )
    });

    // Chaining the options keeps the caveats in their fixed order and drops
    // the ones that do not apply.
    small_sample
        .into_iter()
        .chain(missing_pricing)
        .chain(missing_latency)
        .collect()
}
473
474pub(crate) fn wide_ci_caveats(aggregates: &Aggregates, cis: &ConfidenceIntervals) -> Vec<String> {
478 let mut out = Vec::new();
479 let rel_width = |lo: f64, hi: f64, center: f64| -> Option<f64> {
480 if center.abs() < f64::EPSILON {
481 return None;
482 }
483 Some((hi - lo).abs() / center.abs())
484 };
485 if let Some(w) = rel_width(
486 cis.savings_usd_95.0,
487 cis.savings_usd_95.1,
488 aggregates.projected_savings_usd,
489 ) {
490 if w > 0.30 {
491 out.push(format!(
492 "Savings CI is wide: ±{:.0}% relative width. Treat the headline savings number as a rough estimate; consider scanning a larger window.",
493 w * 100.0
494 ));
495 }
496 }
497 if let Some(w) = rel_width(
498 cis.p50_latency_ms_95.0,
499 cis.p50_latency_ms_95.1,
500 aggregates.p50_latency_ms_projected,
501 ) {
502 if w > 0.30 {
503 out.push(format!(
504 "p50 latency CI is wide: ±{:.0}% relative width.",
505 w * 100.0
506 ));
507 }
508 }
509 out
510}
511
/// Returns the `q`-th percentile of `values` using the nearest-rank rule.
///
/// The rank is `round(q * (len - 1))`, clamped to the last index, so `q = 0.0`
/// yields the minimum and any `q >= 1.0` yields the maximum. Returns `0.0` for
/// an empty slice. The input is not modified.
///
/// Values are ordered with `f64::total_cmp`, so the result is well defined
/// even when NaNs are present (positive NaNs order after every number).
fn percentile(values: &[f64], q: f64) -> f64 {
    if values.is_empty() {
        return 0.0;
    }
    let mut scratch = values.to_vec();
    let last = scratch.len() - 1;
    // A negative or NaN `q` saturates to rank 0 in the float-to-int cast.
    let rank = ((q * (scratch.len() as f64 - 1.0)).round() as usize).min(last);
    // Only one order statistic is needed, so select it instead of fully sorting.
    let (_, nth, _) = scratch.select_nth_unstable_by(rank, |a, b| a.total_cmp(b));
    *nth
}
521
522fn bootstrap_pct_savings_ci(
526 baseline: &[f64],
527 projected: &[f64],
528 seed: u64,
529 iterations: u32,
530) -> (f64, f64) {
531 use rand::Rng;
532 use rand::SeedableRng;
533 use rand_chacha::ChaCha8Rng;
534
535 let n = baseline.len();
536 if n == 0 || iterations == 0 || n != projected.len() {
537 return (0.0, 0.0);
538 }
539 let mut rng = ChaCha8Rng::seed_from_u64(seed);
540 let mut pct_samples: Vec<f64> = Vec::with_capacity(iterations as usize);
541 for _ in 0..iterations {
542 let mut b_sum = 0.0;
543 let mut p_sum = 0.0;
544 for _ in 0..n {
545 let idx = rng.gen_range(0..n);
546 b_sum += baseline[idx];
547 p_sum += projected[idx];
548 }
549 let pct = if b_sum > 0.0 {
550 (b_sum - p_sum) / b_sum * 100.0
551 } else {
552 0.0
553 };
554 pct_samples.push(pct.max(0.0));
555 }
556 pct_samples.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
557 let iter_f = iterations as f64;
558 let lo_idx = ((0.025 * iter_f) as usize).min(pct_samples.len() - 1);
559 let hi_idx = ((0.975 * iter_f) as usize).min(pct_samples.len() - 1);
560 (pct_samples[lo_idx], pct_samples[hi_idx])
561}