1use std::collections::HashMap;
6
7use uuid::Uuid;
8
9use crate::{
10 bootstrap, cost,
11 error::PlanError,
12 routing,
13 types::{
14 Aggregates, ConfidenceIntervals, PerRouteBreakdown, PlanInput, PlanResult, ProposedRoute,
15 RequestLog,
16 },
17};
18
19pub fn replay(input: PlanInput) -> Result<PlanResult, PlanError> {
32 validate(&input)?;
33
34 let mut routes = input.proposed_routes.clone();
42 routes.sort_by(|a, b| b.priority.cmp(&a.priority).then_with(|| a.id.cmp(&b.id)));
43
44 let mut requests = input.requests.clone();
46 requests.sort_by_key(|r| r.id);
47
48 let cache_hit_ids = crate::cache_projection::project_l1_hit_ids(&requests, &input.config);
54 let projection = project_requests(&requests, &routes, &input.pricing, &cache_hit_ids);
55
56 let mut aggregates = aggregate(&projection);
57
58 if input.config.l1_ttl_seconds.is_some() {
63 let proj = crate::cache_projection::project_l1_hits(&requests, &input.config);
64 aggregates.cache_hit_rate_projected = proj.projected_l1_hit_rate;
65 }
66
67 if !requests.is_empty() && requests.iter().any(|r| r.embedding.is_some()) {
73 let l2 = crate::l2_projection::project_l2_hits(&requests, &input.config);
74 aggregates.l2_projections = l2.per_threshold;
75 aggregates.l2_poisoning_candidates = l2.poisoning_candidates;
76 }
77
78 let confidence_intervals = compute_cis(&projection, input.seed, input.bootstrap_iterations);
79 let per_route_breakdown = build_per_route(projection.per_route);
80
81 let proposed_routes = input.proposed_routes;
88
89 let mut caveats = build_caveats(
90 requests.len(),
91 aggregates.requests_unprice_able,
92 projection.latency_unprojected,
93 projection.would_block,
94 );
95 caveats.extend(wide_ci_caveats(&aggregates, &confidence_intervals));
96
97 Ok(PlanResult {
98 plan_id: input.plan_id,
99 org_id: input.org_id,
100 window_start: input.window_start,
101 window_end: input.window_end,
102 sample_size: requests.len() as u32,
103 aggregates,
104 confidence_intervals,
105 per_route_breakdown,
106 caveats,
107 quality: None,
111 proposed_routes,
112 })
113}
114
115pub async fn replay_with_quality<F>(
130 input: PlanInput,
131 judge: &dyn crate::quality::JudgeProvider,
132 quality_config: &crate::quality::QualityConfig,
133 proposed_response_for: F,
134) -> Result<PlanResult, ReplayWithQualityError>
135where
136 F: Fn(&Uuid) -> Option<String>,
137{
138 let requests = input.requests.clone();
141 let mut result = replay(input).map_err(ReplayWithQualityError::Replay)?;
142 let quality =
143 crate::quality::score_quality(&requests, quality_config, judge, proposed_response_for)
144 .await
145 .map_err(ReplayWithQualityError::Quality)?;
146 result.quality = Some(quality);
147 Ok(result)
148}
149
150#[derive(Debug, thiserror::Error)]
154pub enum ReplayWithQualityError {
155 #[error("replay: {0}")]
157 Replay(#[from] crate::error::PlanError),
158 #[error("quality: {0}")]
161 Quality(#[from] crate::quality::QualityError),
162}
163
164fn validate(input: &PlanInput) -> Result<(), PlanError> {
165 if input.window_end <= input.window_start {
166 return Err(PlanError::InvalidWindow {
167 start: input.window_start.to_rfc3339(),
168 end: input.window_end.to_rfc3339(),
169 });
170 }
171 if input.bootstrap_iterations == 0 {
172 return Err(PlanError::ZeroBootstrapIterations);
173 }
174 Ok(())
175}
176
177struct PerRouteBucket {
179 route_id: Uuid,
180 route_name: String,
181 matched: u32,
182 baseline_cost_usd: f64,
183 projected_cost_usd: f64,
184}
185
186struct Projection {
190 per_request_baseline: Vec<f64>,
191 per_request_projected: Vec<f64>,
192 per_request_latency: Vec<f64>,
193 per_request_cache_hit: Vec<f64>,
194 per_route: HashMap<Uuid, PerRouteBucket>,
195 requests_rerouted: u32,
196 requests_unchanged: u32,
197 requests_unprice_able: u32,
198 latency_unprojected: u32,
201 would_block: u32,
204}
205
/// Projects every request through the proposed routes and collects the
/// per-request series consumed by `aggregate` and the bootstrap CIs.
///
/// Each `per_request_*` vector receives exactly one entry per request, in the
/// order of `requests`, so index `i` refers to the same request in all of them.
///
/// Per request:
/// * No route matches: the projected cost is the observed `cost_usd` (0.0 when
///   `cache_hit_ids` contains the request) and latency is unchanged. Counted in
///   `requests_unchanged`.
/// * A route matches but its target model has no pricing entry: same cost and
///   latency treatment as unchanged. Counted in `requests_unprice_able`.
/// * A route matches and is priced:
///   - Cost comes from `cost::project_cost` (0.0 on a projected cache hit).
///   - Latency is the observed median latency of the target model, falling back
///     to the request's own latency (counted in `latency_unprojected`).
///   - The route's bucket is updated. Counted in `requests_rerouted`.
///
/// `per_request_cache_hit` records the observed `req.cached` flag, not
/// membership in `cache_hit_ids`.
fn project_requests(
    requests: &[RequestLog],
    routes: &[ProposedRoute],
    pricing: &crate::types::PricingTable,
    cache_hit_ids: &std::collections::HashSet<Uuid>,
) -> Projection {
    let cap = requests.len();
    let mut per_request_baseline = Vec::with_capacity(cap);
    let mut per_request_projected = Vec::with_capacity(cap);
    let mut per_request_latency = Vec::with_capacity(cap);
    let mut per_request_cache_hit = Vec::with_capacity(cap);
    let mut per_route: HashMap<Uuid, PerRouteBucket> = HashMap::new();
    let mut requests_rerouted: u32 = 0;
    let mut requests_unchanged: u32 = 0;
    let mut requests_unprice_able: u32 = 0;
    let mut latency_unprojected: u32 = 0;
    let mut would_block: u32 = 0;

    // Observed median latency per model, used as the projected latency of a
    // request rerouted to that model.
    let model_medians = model_median_latencies(requests);

    // Fallback lookup model -> provider, built from "provider:model" pricing
    // keys. Keys are sorted first so that, when several providers price the
    // same model, the alphabetically first provider wins deterministically
    // regardless of the table's key iteration order.
    let model_to_provider: HashMap<&str, &str> = {
        let mut keys: Vec<&str> = pricing.keys().map(String::as_str).collect();
        keys.sort_unstable();
        let mut m: HashMap<&str, &str> = HashMap::new();
        for k in keys {
            if let Some((prov, model)) = k.split_once(':') {
                m.entry(model).or_insert(prov);
            }
        }
        m
    };

    for req in requests {
        per_request_baseline.push(req.baseline_cost_usd);
        per_request_cache_hit.push(if req.cached { 1.0 } else { 0.0 });

        // A projected cache hit is costed at 0.0 in every branch below.
        let is_cache_hit = cache_hit_ids.contains(&req.id);

        // NOTE(review): presumably returns the first matching route in slice
        // order; `replay` passes routes sorted by descending priority.
        let matched = routing::match_route(req, routes);
        match matched {
            Some(route) => {
                // Prefer pricing the target model on the request's own provider;
                // otherwise use whichever provider prices that model (or the
                // request's provider again if none does, which then misses below).
                let same_provider_key =
                    crate::types::pricing_key(&req.provider, &route.then.target_model);
                let target_key = if pricing.contains_key(&same_provider_key) {
                    same_provider_key
                } else {
                    let target_provider = model_to_provider
                        .get(route.then.target_model.as_str())
                        .copied()
                        .unwrap_or(req.provider.as_str());
                    crate::types::pricing_key(target_provider, &route.then.target_model)
                };
                if let Some(p) = pricing.get(&target_key) {
                    let projected = cost::project_cost(req, &route.then.target_model, p);
                    let mut projected_cost = if is_cache_hit {
                        0.0
                    } else {
                        projected.cost_usd
                    };
                    // Route-level cost ceiling: if a non-cached projected cost
                    // exceeds `max_cost_usd`, keep the observed cost so it does
                    // not show up as savings, and count it as blocked.
                    // NOTE(review): a blocked request still gets the target
                    // model's latency, increments `requests_rerouted` and feeds
                    // the route bucket; the caveat says "counted unchanged" —
                    // confirm only the cost is meant to be unchanged.
                    if !is_cache_hit
                        && route
                            .then
                            .max_cost_usd
                            .is_some_and(|c| projected.cost_usd > c)
                    {
                        projected_cost = req.cost_usd;
                        would_block += 1;
                    }
                    per_request_projected.push(projected_cost);
                    // Latency: the target model's observed median, or the
                    // request's own latency when the model has no history here.
                    match model_medians.get(route.then.target_model.as_str()) {
                        Some(&med) => per_request_latency.push(med),
                        None => {
                            per_request_latency.push(f64::from(req.latency_ms));
                            latency_unprojected += 1;
                        }
                    }
                    let bucket = per_route.entry(route.id).or_insert_with(|| PerRouteBucket {
                        route_id: route.id,
                        route_name: route.name.clone(),
                        matched: 0,
                        baseline_cost_usd: 0.0,
                        projected_cost_usd: 0.0,
                    });
                    bucket.matched += 1;
                    bucket.baseline_cost_usd += req.baseline_cost_usd;
                    bucket.projected_cost_usd += projected_cost;
                    requests_rerouted += 1;
                } else {
                    // Matched, but the target model cannot be priced: treat the
                    // request as unchanged and count it separately for a caveat.
                    per_request_projected.push(if is_cache_hit { 0.0 } else { req.cost_usd });
                    per_request_latency.push(f64::from(req.latency_ms));
                    requests_unprice_able += 1;
                }
            }
            None => {
                per_request_projected.push(if is_cache_hit { 0.0 } else { req.cost_usd });
                per_request_latency.push(f64::from(req.latency_ms));
                requests_unchanged += 1;
            }
        }
    }

    Projection {
        per_request_baseline,
        per_request_projected,
        per_request_latency,
        per_request_cache_hit,
        per_route,
        requests_rerouted,
        requests_unchanged,
        requests_unprice_able,
        latency_unprojected,
        would_block,
    }
}
342
343fn model_median_latencies(requests: &[RequestLog]) -> HashMap<&str, f64> {
347 let mut by_model: HashMap<&str, Vec<u32>> = HashMap::new();
348 for r in requests {
349 by_model
350 .entry(r.model.as_str())
351 .or_default()
352 .push(r.latency_ms);
353 }
354 by_model
355 .into_iter()
356 .map(|(model, mut lat)| {
357 lat.sort_unstable();
358 (model, f64::from(lat[lat.len() / 2]))
359 })
360 .collect()
361}
362
363fn aggregate(p: &Projection) -> Aggregates {
364 let total_baseline: f64 = p.per_request_baseline.iter().sum();
365 let total_projected: f64 = p.per_request_projected.iter().sum();
366 let projected_savings = (total_baseline - total_projected).max(0.0);
367 let projected_savings_pct = if total_baseline > 0.0 {
368 projected_savings / total_baseline * 100.0
369 } else {
370 0.0
371 };
372 let cache_hit_rate = if p.per_request_cache_hit.is_empty() {
373 0.0
374 } else {
375 p.per_request_cache_hit.iter().sum::<f64>() / p.per_request_cache_hit.len() as f64
376 };
377 let p50_latency = percentile(&p.per_request_latency, 0.50);
378 let p95_latency = percentile(&p.per_request_latency, 0.95);
379
380 Aggregates {
381 total_baseline_cost_usd: total_baseline,
382 total_projected_cost_usd: total_projected,
383 projected_savings_usd: projected_savings,
384 projected_savings_pct,
385 cache_hit_rate_projected: cache_hit_rate,
386 p50_latency_ms_projected: p50_latency,
387 p95_latency_ms_projected: p95_latency,
388 requests_rerouted: p.requests_rerouted,
389 requests_unchanged: p.requests_unchanged,
390 requests_unprice_able: p.requests_unprice_able,
391 l2_projections: Vec::new(),
394 l2_poisoning_candidates: 0,
395 }
396}
397
398fn compute_cis(p: &Projection, seed: u64, iterations: u32) -> ConfidenceIntervals {
399 let n = p.per_request_baseline.len() as f64;
403 let savings_per_req: Vec<f64> = p
404 .per_request_baseline
405 .iter()
406 .zip(p.per_request_projected.iter())
407 .map(|(b, pr)| (b - pr).max(0.0))
408 .collect();
409 let (sv_lo_mean, sv_hi_mean) =
410 bootstrap::bootstrap_ci(&savings_per_req, seed, iterations, (0.025, 0.975));
411 let savings_usd_95 = (sv_lo_mean * n, sv_hi_mean * n);
412
413 let savings_pct_95 = bootstrap_pct_savings_ci(
416 &p.per_request_baseline,
417 &p.per_request_projected,
418 seed.wrapping_add(1),
419 iterations,
420 );
421
422 let cache_hit_rate_95 = bootstrap::bootstrap_ci(
425 &p.per_request_cache_hit,
426 seed.wrapping_add(2),
427 iterations,
428 (0.025, 0.975),
429 );
430
431 let p50_latency_ms_95 = bootstrap_percentile_ci(
434 &p.per_request_latency,
435 0.50,
436 seed.wrapping_add(3),
437 iterations,
438 );
439 let p95_latency_ms_95 = bootstrap_percentile_ci(
440 &p.per_request_latency,
441 0.95,
442 seed.wrapping_add(4),
443 iterations,
444 );
445
446 ConfidenceIntervals {
447 savings_usd_95,
448 savings_pct_95,
449 cache_hit_rate_95,
450 p50_latency_ms_95,
451 p95_latency_ms_95,
452 }
453}
454
455fn bootstrap_percentile_ci(values: &[f64], q: f64, seed: u64, iterations: u32) -> (f64, f64) {
459 use rand::{Rng, SeedableRng};
460 use rand_chacha::ChaCha8Rng;
461 if values.is_empty() || iterations == 0 {
462 return (0.0, 0.0);
463 }
464 let n = values.len();
465 let mut rng = ChaCha8Rng::seed_from_u64(seed);
466 let mut samples: Vec<f64> = Vec::with_capacity(iterations as usize);
467 let mut buf: Vec<f64> = Vec::with_capacity(n);
468 for _ in 0..iterations {
469 buf.clear();
470 for _ in 0..n {
471 buf.push(values[rng.gen_range(0..n)]);
472 }
473 samples.push(percentile(&buf, q));
474 }
475 samples.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
476 let lo_idx = (0.025 * iterations as f64) as usize;
477 let hi_idx = ((0.975 * iterations as f64) as usize).min(iterations as usize - 1);
478 (samples[lo_idx], samples[hi_idx])
479}
480
481fn build_per_route(buckets: HashMap<Uuid, PerRouteBucket>) -> Vec<PerRouteBreakdown> {
482 let mut rows: Vec<PerRouteBreakdown> = buckets
483 .into_values()
484 .map(|b| PerRouteBreakdown {
485 route_id: b.route_id,
486 route_name: b.route_name,
487 matched: b.matched,
488 baseline_cost_usd: b.baseline_cost_usd,
489 projected_cost_usd: b.projected_cost_usd,
490 savings_usd: (b.baseline_cost_usd - b.projected_cost_usd).max(0.0),
491 })
492 .collect();
493 rows.sort_by_key(|r| r.route_id);
496 rows
497}
498
/// Builds the fixed caveats that qualify a plan's numbers, in a stable order.
///
/// The four caveats, in output order:
/// * small sample: fewer than 1000 requests;
/// * unpriced targets: `requests_unprice_able > 0`;
/// * unprojected latency: `latency_unprojected > 0`;
/// * cost ceiling: `would_block > 0`.
fn build_caveats(
    sample_size: usize,
    requests_unprice_able: u32,
    latency_unprojected: u32,
    would_block: u32,
) -> Vec<String> {
    let candidates: Vec<Option<String>> = vec![
        (sample_size < 1000).then(|| {
            format!("Small sample size ({sample_size} requests) — confidence intervals are wide.")
        }),
        (requests_unprice_able > 0).then(|| {
            format!(
                "{requests_unprice_able} request(s) routed to a target model with no pricing entry — counted as unchanged."
            )
        }),
        (latency_unprojected > 0).then(|| {
            format!(
                "{latency_unprojected} rerouted request(s) had no latency history for the target model — their latency is shown unchanged, not projected."
            )
        }),
        (would_block > 0).then(|| {
            format!(
                "{would_block} request(s) would be rejected by a max_cost_usd ceiling — counted unchanged, not as savings."
            )
        }),
    ];
    candidates.into_iter().flatten().collect()
}
528
529pub(crate) fn wide_ci_caveats(aggregates: &Aggregates, cis: &ConfidenceIntervals) -> Vec<String> {
533 let mut out = Vec::new();
534 let rel_width = |lo: f64, hi: f64, center: f64| -> Option<f64> {
535 if center.abs() < f64::EPSILON {
536 return None;
537 }
538 Some((hi - lo).abs() / center.abs())
539 };
540 if let Some(w) = rel_width(
541 cis.savings_usd_95.0,
542 cis.savings_usd_95.1,
543 aggregates.projected_savings_usd,
544 ) {
545 if w > 0.30 {
546 out.push(format!(
547 "Savings CI is wide: ±{:.0}% relative width. Treat the headline savings number as a rough estimate; consider scanning a larger window.",
548 w * 100.0
549 ));
550 }
551 }
552 if let Some(w) = rel_width(
553 cis.p50_latency_ms_95.0,
554 cis.p50_latency_ms_95.1,
555 aggregates.p50_latency_ms_projected,
556 ) {
557 if w > 0.30 {
558 out.push(format!(
559 "p50 latency CI is wide: ±{:.0}% relative width.",
560 w * 100.0
561 ));
562 }
563 }
564 out
565}
566
/// Nearest-rank quantile of `values`; the input does not need to be sorted.
///
/// `q` is a fraction in `[0, 1]`. The element at sorted index
/// `round(q * (len - 1))` is returned. An out-of-range `q` resolves to the
/// first/last element, and an empty slice yields `0.0`.
///
/// This runs once per bootstrap resample, so it uses an O(n) selection rather
/// than a full sort. It orders with `f64::total_cmp`, a true total order, so
/// NaN inputs cannot violate the selection's ordering contract (NaN sorts
/// above every number).
fn percentile(values: &[f64], q: f64) -> f64 {
    if values.is_empty() {
        return 0.0;
    }
    let last = values.len() - 1;
    // `as usize` saturates negative/NaN to 0; `min` caps q > 1 at the last element.
    let idx = ((q * last as f64).round() as usize).min(last);
    let mut scratch = values.to_vec();
    let (_, value, _) = scratch.select_nth_unstable_by(idx, f64::total_cmp);
    *value
}
576
577fn bootstrap_pct_savings_ci(
581 baseline: &[f64],
582 projected: &[f64],
583 seed: u64,
584 iterations: u32,
585) -> (f64, f64) {
586 use rand::Rng;
587 use rand::SeedableRng;
588 use rand_chacha::ChaCha8Rng;
589
590 let n = baseline.len();
591 if n == 0 || iterations == 0 || n != projected.len() {
592 return (0.0, 0.0);
593 }
594 let mut rng = ChaCha8Rng::seed_from_u64(seed);
595 let mut pct_samples: Vec<f64> = Vec::with_capacity(iterations as usize);
596 for _ in 0..iterations {
597 let mut b_sum = 0.0;
598 let mut p_sum = 0.0;
599 for _ in 0..n {
600 let idx = rng.gen_range(0..n);
601 b_sum += baseline[idx];
602 p_sum += projected[idx];
603 }
604 let pct = if b_sum > 0.0 {
605 (b_sum - p_sum) / b_sum * 100.0
606 } else {
607 0.0
608 };
609 pct_samples.push(pct.max(0.0));
610 }
611 pct_samples.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
612 let iter_f = iterations as f64;
613 let lo_idx = ((0.025 * iter_f) as usize).min(pct_samples.len() - 1);
614 let hi_idx = ((0.975 * iter_f) as usize).min(pct_samples.len() - 1);
615 (pct_samples[lo_idx], pct_samples[hi_idx])
616}