Skip to main content

tarn/
bench.rs

1use crate::error::TarnError;
2use crate::http;
3use crate::interpolation::{self, Context};
4use crate::model::{AuthConfig, HttpTransportConfig, HttpVersionPreference, Step, TestFile};
5use base64::Engine;
6use indexmap::IndexMap;
7use std::collections::HashMap;
8use std::sync::atomic::{AtomicU64, Ordering};
9use std::sync::Arc;
10use std::time::{Duration, Instant};
11use tokio::sync::Semaphore;
12
13/// Options for a benchmark run.
14#[derive(Debug, Clone)]
15pub struct BenchOptions {
16    /// Total number of requests to send
17    pub requests: u64,
18    /// Number of concurrent workers
19    pub concurrency: u64,
20    /// Ramp-up duration (gradually add workers)
21    pub ramp_up: Option<Duration>,
22    /// Optional CI-style threshold checks
23    pub thresholds: BenchThresholds,
24}
25
/// Optional CI gates checked against the aggregated result; a `None` field
/// disables the corresponding gate.
#[derive(Clone, Debug, Default)]
pub struct BenchThresholds {
    /// Lowest acceptable throughput, in requests per second.
    pub min_throughput_rps: Option<f64>,
    /// Highest acceptable error rate, expressed as a percentage (0–100).
    pub max_error_rate: Option<f64>,
    /// Highest acceptable 95th-percentile latency, in milliseconds.
    pub max_p95_ms: Option<u64>,
    /// Highest acceptable 99th-percentile latency, in milliseconds.
    pub max_p99_ms: Option<u64>,
}
33
34/// Result of a single request in the benchmark.
35#[derive(Debug, Clone)]
36struct RequestResult {
37    status: u16,
38    success: bool,
39    error: Option<String>,
40    timings: RequestTimings,
41}
42
/// Millisecond timings captured for a single request.
#[derive(Clone, Debug)]
struct RequestTimings {
    /// End-to-end time, from sending the request until its body was consumed.
    total_ms: u64,
    /// Time until the response headers were received ("time to first byte").
    ttfb_ms: u64,
    /// Time spent reading the response body after the headers arrived.
    body_read_ms: u64,
}
49
50#[derive(Debug, Clone)]
51enum BenchPayload {
52    Json(serde_json::Value),
53    Form(IndexMap<String, String>),
54}
55
56fn effective_auth<'a>(step: &'a Step, test_file: &'a TestFile) -> Option<&'a AuthConfig> {
57    step.request.auth.as_ref().or_else(|| {
58        test_file
59            .defaults
60            .as_ref()
61            .and_then(|defaults| defaults.auth.as_ref())
62    })
63}
64
65fn apply_auth_header(
66    headers: &mut HashMap<String, String>,
67    auth: Option<&AuthConfig>,
68    ctx: &Context,
69) {
70    if headers
71        .keys()
72        .any(|key| key.eq_ignore_ascii_case("authorization"))
73    {
74        return;
75    }
76
77    let Some(auth) = auth else {
78        return;
79    };
80
81    if let Some(token) = auth.bearer.as_ref() {
82        headers.insert(
83            "Authorization".into(),
84            format!("Bearer {}", interpolation::interpolate(token, ctx)),
85        );
86    } else if let Some(basic) = auth.basic.as_ref() {
87        let username = interpolation::interpolate(&basic.username, ctx);
88        let password = interpolation::interpolate(&basic.password, ctx);
89        let encoded =
90            base64::engine::general_purpose::STANDARD.encode(format!("{username}:{password}"));
91        headers.insert("Authorization".into(), format!("Basic {encoded}"));
92    }
93}
94
95/// Aggregated benchmark results.
96#[derive(Debug, Clone, serde::Serialize)]
97pub struct BenchResult {
98    pub step_name: String,
99    pub method: String,
100    pub url: String,
101    pub concurrency: u64,
102    pub ramp_up_ms: Option<u64>,
103    pub total_requests: u64,
104    pub successful: u64,
105    pub failed: u64,
106    pub error_rate: f64,
107    pub total_duration_ms: u64,
108    pub throughput_rps: f64,
109    pub latency: LatencyStats,
110    pub timings: TimingBreakdown,
111    pub status_codes: HashMap<u16, u64>,
112    pub errors: Vec<String>,
113    pub gates: Vec<BenchGateResult>,
114    pub passed_gates: bool,
115}
116
117#[derive(Debug, Clone, serde::Serialize)]
118pub struct LatencyStats {
119    pub min_ms: u64,
120    pub max_ms: u64,
121    pub mean_ms: f64,
122    pub median_ms: u64,
123    pub p95_ms: u64,
124    pub p99_ms: u64,
125    pub stdev_ms: f64,
126}
127
128#[derive(Debug, Clone, serde::Serialize)]
129pub struct TimingBreakdown {
130    pub total: LatencyStats,
131    pub ttfb: LatencyStats,
132    pub body_read: LatencyStats,
133    pub connect: Option<LatencyStats>,
134    pub tls: Option<LatencyStats>,
135}
136
137#[derive(Debug, Clone, serde::Serialize)]
138pub struct BenchGateResult {
139    pub name: String,
140    pub passed: bool,
141    pub expected: String,
142    pub actual: String,
143    pub message: String,
144}
145
146/// Run a benchmark against a single step from a test file.
147pub fn run_bench(
148    test_file: &TestFile,
149    step_index: usize,
150    env: &HashMap<String, String>,
151    opts: &BenchOptions,
152    http_config: &HttpTransportConfig,
153) -> Result<BenchResult, TarnError> {
154    let step = resolve_step(test_file, step_index)?;
155
156    // Build interpolation context
157    let ctx = Context {
158        env: env.clone(),
159        captures: HashMap::new(),
160        optional_unset: std::collections::HashSet::new(),
161    };
162
163    // Interpolate the request once (captures won't work in bench mode)
164    let url = interpolation::interpolate(&step.request.url, &ctx);
165    let mut merged_headers = test_file
166        .defaults
167        .as_ref()
168        .map(|d| d.headers.clone())
169        .unwrap_or_default();
170    for (k, v) in &step.request.headers {
171        merged_headers.insert(k.clone(), v.clone());
172    }
173    apply_auth_header(&mut merged_headers, effective_auth(step, test_file), &ctx);
174    let payload = if let Some(ref form) = step.request.form {
175        let form = interpolation::interpolate_string_map(form, &ctx);
176        merged_headers
177            .entry("Content-Type".to_string())
178            .or_insert_with(|| "application/x-www-form-urlencoded".to_string());
179        Some(BenchPayload::Form(form))
180    } else {
181        step.request
182            .body
183            .as_ref()
184            .map(|b| BenchPayload::Json(interpolation::interpolate_json(b, &ctx)))
185    };
186    let headers = interpolation::interpolate_headers(&merged_headers, &ctx);
187
188    let method = step.request.method.clone();
189    let step_name = step.name.clone();
190
191    // Expected status from assertions (if any) — extract exact status for bench mode
192    let expected_status = step.assertions.as_ref().and_then(|a| {
193        a.status.as_ref().and_then(|s| match s {
194            crate::model::StatusAssertion::Exact(code) => Some(*code),
195            _ => None, // Bench mode only supports exact status checks
196        })
197    });
198
199    let bench_req = BenchRequest {
200        step_name: &step_name,
201        method: &method,
202        url: &url,
203        headers: &headers,
204        payload: payload.as_ref(),
205        expected_status,
206    };
207
208    // Run the benchmark using tokio runtime
209    let rt = tokio::runtime::Runtime::new()
210        .map_err(|e| TarnError::Http(format!("Failed to create async runtime: {}", e)))?;
211
212    let result = rt.block_on(run_bench_async(&bench_req, opts, http_config))?;
213
214    Ok(result)
215}
216
217struct BenchRequest<'a> {
218    step_name: &'a str,
219    method: &'a str,
220    url: &'a str,
221    headers: &'a HashMap<String, String>,
222    payload: Option<&'a BenchPayload>,
223    expected_status: Option<u16>,
224}
225
226async fn run_bench_async(
227    req: &BenchRequest<'_>,
228    opts: &BenchOptions,
229    http_config: &HttpTransportConfig,
230) -> Result<BenchResult, TarnError> {
231    let semaphore = Arc::new(Semaphore::new(opts.concurrency as usize));
232    let completed = Arc::new(AtomicU64::new(0));
233
234    let client = http::build_async_client_with_timeout(http_config, Some(Duration::from_secs(30)))?;
235
236    let overall_start = Instant::now();
237    let http_version = http_config.http_version;
238
239    let mut handles = Vec::with_capacity(opts.requests as usize);
240
241    for i in 0..opts.requests {
242        // Ramp-up: stagger initial requests
243        if let Some(ramp) = opts.ramp_up {
244            if i < opts.concurrency {
245                let delay_per_worker = ramp / opts.concurrency as u32;
246                let delay = delay_per_worker * i as u32;
247                tokio::time::sleep(delay).await;
248            }
249        }
250
251        let permit = semaphore.clone().acquire_owned().await.unwrap();
252        let client = client.clone();
253        let method = req.method.to_string();
254        let url = req.url.to_string();
255        let headers = req.headers.clone();
256        let payload = req.payload.cloned();
257        let completed = completed.clone();
258        let expected_status = req.expected_status;
259
260        let handle = tokio::spawn(async move {
261            let result = execute_single(
262                &client,
263                &method,
264                &url,
265                &headers,
266                payload.as_ref(),
267                expected_status,
268                http_version,
269            )
270            .await;
271            completed.fetch_add(1, Ordering::Relaxed);
272            drop(permit);
273            result
274        });
275
276        handles.push(handle);
277    }
278
279    // Collect results
280    let mut results = Vec::with_capacity(opts.requests as usize);
281    for handle in handles {
282        match handle.await {
283            Ok(r) => results.push(r),
284            Err(e) => results.push(RequestResult {
285                status: 0,
286                success: false,
287                error: Some(format!("Task failed: {}", e)),
288                timings: RequestTimings {
289                    total_ms: 0,
290                    ttfb_ms: 0,
291                    body_read_ms: 0,
292                },
293            }),
294        }
295    }
296
297    let total_duration_ms = overall_start.elapsed().as_millis() as u64;
298
299    let mut result = aggregate_results(
300        req.step_name,
301        req.method,
302        req.url,
303        opts.concurrency,
304        opts.ramp_up,
305        results,
306        total_duration_ms,
307    );
308    result.gates = evaluate_gates(&result, &opts.thresholds);
309    result.passed_gates = result.gates.iter().all(|gate| gate.passed);
310
311    Ok(result)
312}
313
314async fn execute_single(
315    client: &reqwest::Client,
316    method: &str,
317    url: &str,
318    headers: &HashMap<String, String>,
319    payload: Option<&BenchPayload>,
320    expected_status: Option<u16>,
321    http_version: Option<HttpVersionPreference>,
322) -> RequestResult {
323    let req_method = match reqwest::Method::from_bytes(method.trim().as_bytes()) {
324        Ok(method) => method,
325        Err(error) => {
326            return RequestResult {
327                status: 0,
328                success: false,
329                error: Some(format!("Invalid HTTP method '{}': {}", method, error)),
330                timings: RequestTimings {
331                    total_ms: 0,
332                    ttfb_ms: 0,
333                    body_read_ms: 0,
334                },
335            }
336        }
337    };
338
339    let mut builder = client.request(req_method, url);
340
341    builder = match http_version {
342        Some(HttpVersionPreference::Http1_1) => builder.version(reqwest::Version::HTTP_11),
343        Some(HttpVersionPreference::Http2) => builder.version(reqwest::Version::HTTP_2),
344        None => builder,
345    };
346
347    for (k, v) in headers {
348        builder = builder.header(k, v);
349    }
350
351    if let Some(payload) = payload {
352        builder = match payload {
353            BenchPayload::Json(body) => builder.json(body),
354            BenchPayload::Form(form) => match http::encode_form_body(form) {
355                Ok(body) => builder.body(body),
356                Err(error) => {
357                    return RequestResult {
358                        status: 0,
359                        success: false,
360                        error: Some(error.to_string()),
361                        timings: RequestTimings {
362                            total_ms: 0,
363                            ttfb_ms: 0,
364                            body_read_ms: 0,
365                        },
366                    }
367                }
368            },
369        };
370    }
371
372    let start = Instant::now();
373    match builder.send().await {
374        Ok(resp) => {
375            let status = resp.status().as_u16();
376            let ttfb_ms = start.elapsed().as_millis() as u64;
377            let body_start = Instant::now();
378            // Consume body to complete the request
379            let _ = resp.bytes().await;
380            let body_read_ms = body_start.elapsed().as_millis() as u64;
381            let duration_ms = ttfb_ms.saturating_add(body_read_ms);
382            let success = expected_status.map(|e| e == status).unwrap_or(true);
383            RequestResult {
384                status,
385                success,
386                error: if success {
387                    None
388                } else {
389                    Some(format!(
390                        "Expected status {}, got {}",
391                        expected_status.unwrap(),
392                        status
393                    ))
394                },
395                timings: RequestTimings {
396                    total_ms: duration_ms,
397                    ttfb_ms,
398                    body_read_ms,
399                },
400            }
401        }
402        Err(e) => RequestResult {
403            status: 0,
404            success: false,
405            error: Some(e.to_string()),
406            timings: RequestTimings {
407                total_ms: 0,
408                ttfb_ms: 0,
409                body_read_ms: 0,
410            },
411        },
412    }
413}
414
415fn aggregate_results(
416    step_name: &str,
417    method: &str,
418    url: &str,
419    concurrency: u64,
420    ramp_up: Option<Duration>,
421    results: Vec<RequestResult>,
422    total_duration_ms: u64,
423) -> BenchResult {
424    let total = results.len() as u64;
425    let successful = results.iter().filter(|r| r.success).count() as u64;
426    let failed = total - successful;
427
428    // Collect timings from successful requests
429    let mut total_latencies: Vec<u64> = results
430        .iter()
431        .filter(|r| r.success)
432        .map(|r| r.timings.total_ms)
433        .collect();
434    let mut ttfb_latencies: Vec<u64> = results
435        .iter()
436        .filter(|r| r.success)
437        .map(|r| r.timings.ttfb_ms)
438        .collect();
439    let mut body_read_latencies: Vec<u64> = results
440        .iter()
441        .filter(|r| r.success)
442        .map(|r| r.timings.body_read_ms)
443        .collect();
444    total_latencies.sort();
445    ttfb_latencies.sort();
446    body_read_latencies.sort();
447
448    let latency = summarize_latencies(&total_latencies);
449    let timings = TimingBreakdown {
450        total: latency.clone(),
451        ttfb: summarize_latencies(&ttfb_latencies),
452        body_read: summarize_latencies(&body_read_latencies),
453        connect: None,
454        tls: None,
455    };
456
457    // Status code distribution
458    let mut status_codes: HashMap<u16, u64> = HashMap::new();
459    for r in &results {
460        if r.status > 0 {
461            *status_codes.entry(r.status).or_insert(0) += 1;
462        }
463    }
464
465    // Unique errors (limit to 10)
466    let mut errors: Vec<String> = Vec::new();
467    for r in &results {
468        if let Some(ref e) = r.error {
469            if errors.len() < 10 && !errors.contains(e) {
470                errors.push(e.clone());
471            }
472        }
473    }
474
475    let throughput = if total_duration_ms > 0 {
476        (total as f64 / total_duration_ms as f64) * 1000.0
477    } else {
478        0.0
479    };
480
481    BenchResult {
482        step_name: step_name.to_string(),
483        method: method.to_string(),
484        url: url.to_string(),
485        concurrency,
486        ramp_up_ms: ramp_up.map(|duration| duration.as_millis() as u64),
487        total_requests: total,
488        successful,
489        failed,
490        error_rate: if total > 0 {
491            (failed as f64 / total as f64) * 100.0
492        } else {
493            0.0
494        },
495        total_duration_ms,
496        throughput_rps: (throughput * 100.0).round() / 100.0,
497        latency,
498        timings,
499        status_codes,
500        errors,
501        gates: Vec::new(),
502        passed_gates: true,
503    }
504}
505
506fn summarize_latencies(latencies: &[u64]) -> LatencyStats {
507    if latencies.is_empty() {
508        return LatencyStats {
509            min_ms: 0,
510            max_ms: 0,
511            mean_ms: 0.0,
512            median_ms: 0,
513            p95_ms: 0,
514            p99_ms: 0,
515            stdev_ms: 0.0,
516        };
517    }
518
519    let min = *latencies.first().unwrap();
520    let max = *latencies.last().unwrap();
521    let sum: u64 = latencies.iter().sum();
522    let mean = sum as f64 / latencies.len() as f64;
523    let median = percentile(latencies, 50.0);
524    let p95 = percentile(latencies, 95.0);
525    let p99 = percentile(latencies, 99.0);
526    let variance = latencies
527        .iter()
528        .map(|&value| {
529            let diff = value as f64 - mean;
530            diff * diff
531        })
532        .sum::<f64>()
533        / latencies.len() as f64;
534
535    LatencyStats {
536        min_ms: min,
537        max_ms: max,
538        mean_ms: (mean * 100.0).round() / 100.0,
539        median_ms: median,
540        p95_ms: p95,
541        p99_ms: p99,
542        stdev_ms: (variance.sqrt() * 100.0).round() / 100.0,
543    }
544}
545
546fn evaluate_gates(result: &BenchResult, thresholds: &BenchThresholds) -> Vec<BenchGateResult> {
547    let mut gates = Vec::new();
548
549    if let Some(min_rps) = thresholds.min_throughput_rps {
550        let passed = result.throughput_rps >= min_rps;
551        gates.push(BenchGateResult {
552            name: "throughput_rps".into(),
553            passed,
554            expected: format!(">= {:.2}", min_rps),
555            actual: format!("{:.2}", result.throughput_rps),
556            message: if passed {
557                "Throughput gate passed".into()
558            } else {
559                "Throughput dropped below the configured floor".into()
560            },
561        });
562    }
563
564    if let Some(max_error_rate) = thresholds.max_error_rate {
565        let passed = result.error_rate <= max_error_rate;
566        gates.push(BenchGateResult {
567            name: "error_rate".into(),
568            passed,
569            expected: format!("<= {:.2}", max_error_rate),
570            actual: format!("{:.2}", result.error_rate),
571            message: if passed {
572                "Error-rate gate passed".into()
573            } else {
574                "Error rate exceeded the configured ceiling".into()
575            },
576        });
577    }
578
579    if let Some(max_p95_ms) = thresholds.max_p95_ms {
580        let passed = result.latency.p95_ms <= max_p95_ms;
581        gates.push(BenchGateResult {
582            name: "latency_p95_ms".into(),
583            passed,
584            expected: format!("<= {}", max_p95_ms),
585            actual: result.latency.p95_ms.to_string(),
586            message: if passed {
587                "P95 latency gate passed".into()
588            } else {
589                "P95 latency exceeded the configured ceiling".into()
590            },
591        });
592    }
593
594    if let Some(max_p99_ms) = thresholds.max_p99_ms {
595        let passed = result.latency.p99_ms <= max_p99_ms;
596        gates.push(BenchGateResult {
597            name: "latency_p99_ms".into(),
598            passed,
599            expected: format!("<= {}", max_p99_ms),
600            actual: result.latency.p99_ms.to_string(),
601            message: if passed {
602                "P99 latency gate passed".into()
603            } else {
604                "P99 latency exceeded the configured ceiling".into()
605            },
606        });
607    }
608
609    gates
610}
611
/// Nearest-rank percentile of an ascending slice.
///
/// `pct` is in percent (0–100). The rank is `round(pct / 100 * (len - 1))`,
/// clamped to the last element. An empty slice yields `0`.
fn percentile(sorted: &[u64], pct: f64) -> u64 {
    let Some(last) = sorted.len().checked_sub(1) else {
        return 0;
    };
    let rank = ((pct / 100.0) * last as f64).round() as usize;
    sorted[rank.min(last)]
}
619
620fn resolve_step(test_file: &TestFile, step_index: usize) -> Result<&Step, TarnError> {
621    // Try flat steps first
622    if !test_file.steps.is_empty() {
623        return test_file.steps.get(step_index).ok_or_else(|| {
624            TarnError::Config(format!(
625                "Step index {} out of range (file has {} steps)",
626                step_index,
627                test_file.steps.len()
628            ))
629        });
630    }
631
632    // Then try first test group's steps
633    if let Some((_, group)) = test_file.tests.iter().next() {
634        return group.steps.get(step_index).ok_or_else(|| {
635            TarnError::Config(format!(
636                "Step index {} out of range (test has {} steps)",
637                step_index,
638                group.steps.len()
639            ))
640        });
641    }
642
643    Err(TarnError::Config("No steps found in test file".into()))
644}
645
646/// Render benchmark results as human-readable output.
647pub fn render_human(result: &BenchResult) -> String {
648    use colored::Colorize;
649
650    let mut out = String::new();
651
652    out.push_str(&format!(
653        "\n {} {} {} — {} requests, {} concurrent\n\n",
654        "TARN BENCH".bold().white().on_blue(),
655        result.method.bold(),
656        result.url.dimmed(),
657        result.total_requests,
658        result.concurrency,
659    ));
660
661    // Request summary
662    let ok_str = result.successful.to_string().green();
663    let fail_str = if result.failed > 0 {
664        result.failed.to_string().red()
665    } else {
666        result.failed.to_string().dimmed()
667    };
668    out.push_str(&format!(
669        "  {:<14} {} total, {} ok, {} failed ({:.1}%)\n",
670        "Requests:".bold(),
671        result.total_requests,
672        ok_str,
673        fail_str,
674        result.error_rate
675    ));
676
677    // Duration & throughput
678    let dur = if result.total_duration_ms >= 1000 {
679        format!("{:.2}s", result.total_duration_ms as f64 / 1000.0)
680    } else {
681        format!("{}ms", result.total_duration_ms)
682    };
683    out.push_str(&format!("  {:<14} {}\n", "Duration:".bold(), dur));
684    out.push_str(&format!(
685        "  {:<14} {:.1} req/s\n",
686        "Throughput:".bold(),
687        result.throughput_rps
688    ));
689
690    // Latency
691    out.push_str(&format!("\n  {}:\n", "Latency".bold()));
692    out.push_str(&format!("    {:<10} {}ms\n", "min", result.latency.min_ms));
693    out.push_str(&format!(
694        "    {:<10} {}ms\n",
695        "p50", result.latency.median_ms
696    ));
697    out.push_str(&format!(
698        "    {:<10} {}ms\n",
699        "p95".yellow(),
700        result.latency.p95_ms.to_string().yellow()
701    ));
702    out.push_str(&format!(
703        "    {:<10} {}ms\n",
704        "p99".red(),
705        result.latency.p99_ms.to_string().red()
706    ));
707    out.push_str(&format!("    {:<10} {}ms\n", "max", result.latency.max_ms));
708    out.push_str(&format!(
709        "    {:<10} {:.2}ms\n",
710        "stdev", result.latency.stdev_ms
711    ));
712
713    out.push_str(&format!("\n  {}:\n", "Timings".bold()));
714    out.push_str(&format!(
715        "    {:<10} p50={}ms p95={}ms p99={}ms\n",
716        "ttfb",
717        result.timings.ttfb.median_ms,
718        result.timings.ttfb.p95_ms,
719        result.timings.ttfb.p99_ms
720    ));
721    out.push_str(&format!(
722        "    {:<10} p50={}ms p95={}ms p99={}ms\n",
723        "body-read",
724        result.timings.body_read.median_ms,
725        result.timings.body_read.p95_ms,
726        result.timings.body_read.p99_ms
727    ));
728    out.push_str("    connect    n/a (reqwest client does not expose phase timing)\n");
729    out.push_str("    tls        n/a (reqwest client does not expose phase timing)\n");
730
731    // Status codes
732    if !result.status_codes.is_empty() {
733        out.push_str(&format!("\n  {}:\n", "Status codes".bold()));
734        let mut codes: Vec<_> = result.status_codes.iter().collect();
735        codes.sort_by_key(|(code, _)| *code);
736        for (code, count) in codes {
737            let code_str = if *code >= 200 && *code < 300 {
738                code.to_string().green().to_string()
739            } else if *code >= 400 {
740                code.to_string().red().to_string()
741            } else {
742                code.to_string()
743            };
744            out.push_str(&format!("    {} — {} responses\n", code_str, count));
745        }
746    }
747
748    // Errors
749    if !result.errors.is_empty() {
750        out.push_str(&format!("\n  {}:\n", "Errors".bold().red()));
751        for e in &result.errors {
752            out.push_str(&format!("    - {}\n", e.red()));
753        }
754    }
755
756    if !result.gates.is_empty() {
757        out.push_str(&format!("\n  {}:\n", "CI gates".bold()));
758        for gate in &result.gates {
759            let status = if gate.passed {
760                "PASS".green().to_string()
761            } else {
762                "FAIL".red().to_string()
763            };
764            out.push_str(&format!(
765                "    [{}] {} expected {} actual {}\n",
766                status, gate.name, gate.expected, gate.actual
767            ));
768        }
769    }
770
771    out.push('\n');
772    out
773}
774
775/// Render benchmark results as JSON.
776pub fn render_json(result: &BenchResult) -> String {
777    serde_json::to_string_pretty(result).unwrap_or_else(|_| "{}".to_string())
778}
779
780/// Render benchmark results as a single-row CSV summary.
781pub fn render_csv(result: &BenchResult) -> String {
782    let mut lines = vec![[
783        "step_name",
784        "method",
785        "url",
786        "concurrency",
787        "requests",
788        "successful",
789        "failed",
790        "error_rate",
791        "throughput_rps",
792        "latency_p50_ms",
793        "latency_p95_ms",
794        "latency_p99_ms",
795        "ttfb_p50_ms",
796        "ttfb_p95_ms",
797        "ttfb_p99_ms",
798        "body_read_p50_ms",
799        "body_read_p95_ms",
800        "body_read_p99_ms",
801        "passed_gates",
802    ]
803    .join(",")];
804
805    lines.push(
806        vec![
807            csv_escape(&result.step_name),
808            csv_escape(&result.method),
809            csv_escape(&result.url),
810            result.concurrency.to_string(),
811            result.total_requests.to_string(),
812            result.successful.to_string(),
813            result.failed.to_string(),
814            format!("{:.2}", result.error_rate),
815            format!("{:.2}", result.throughput_rps),
816            result.latency.median_ms.to_string(),
817            result.latency.p95_ms.to_string(),
818            result.latency.p99_ms.to_string(),
819            result.timings.ttfb.median_ms.to_string(),
820            result.timings.ttfb.p95_ms.to_string(),
821            result.timings.ttfb.p99_ms.to_string(),
822            result.timings.body_read.median_ms.to_string(),
823            result.timings.body_read.p95_ms.to_string(),
824            result.timings.body_read.p99_ms.to_string(),
825            result.passed_gates.to_string(),
826        ]
827        .join(","),
828    );
829
830    lines.join("\n") + "\n"
831}
832
/// Quote a CSV field per RFC 4180.
///
/// The field is wrapped in double quotes when it contains a comma, a double quote,
/// or a line break (LF or CR); embedded double quotes are doubled. Any other value
/// is returned unchanged.
fn csv_escape(value: &str) -> String {
    // CR must trigger quoting as well as LF: an unquoted carriage return is read as
    // a record break by many CSV parsers and would split the row.
    if value.contains([',', '"', '\n', '\r']) {
        format!("\"{}\"", value.replace('"', "\"\""))
    } else {
        value.to_string()
    }
}
840
841#[cfg(test)]
842mod tests {
843    use super::*;
844
845    fn request_result(
846        duration_ms: u64,
847        status: u16,
848        success: bool,
849        error: Option<&str>,
850    ) -> RequestResult {
851        RequestResult {
852            status,
853            success,
854            error: error.map(str::to_string),
855            timings: RequestTimings {
856                total_ms: duration_ms,
857                ttfb_ms: duration_ms / 2,
858                body_read_ms: duration_ms.saturating_sub(duration_ms / 2),
859            },
860        }
861    }
862
863    fn sample_bench_result() -> BenchResult {
864        BenchResult {
865            step_name: "test".into(),
866            method: "GET".into(),
867            url: "http://localhost".into(),
868            concurrency: 10,
869            ramp_up_ms: None,
870            total_requests: 10,
871            successful: 9,
872            failed: 1,
873            error_rate: 10.0,
874            total_duration_ms: 500,
875            throughput_rps: 20.0,
876            latency: LatencyStats {
877                min_ms: 5,
878                max_ms: 50,
879                mean_ms: 20.0,
880                median_ms: 18,
881                p95_ms: 45,
882                p99_ms: 50,
883                stdev_ms: 12.5,
884            },
885            timings: TimingBreakdown {
886                total: LatencyStats {
887                    min_ms: 5,
888                    max_ms: 50,
889                    mean_ms: 20.0,
890                    median_ms: 18,
891                    p95_ms: 45,
892                    p99_ms: 50,
893                    stdev_ms: 12.5,
894                },
895                ttfb: LatencyStats {
896                    min_ms: 2,
897                    max_ms: 20,
898                    mean_ms: 8.0,
899                    median_ms: 7,
900                    p95_ms: 18,
901                    p99_ms: 20,
902                    stdev_ms: 3.5,
903                },
904                body_read: LatencyStats {
905                    min_ms: 1,
906                    max_ms: 30,
907                    mean_ms: 12.0,
908                    median_ms: 10,
909                    p95_ms: 27,
910                    p99_ms: 30,
911                    stdev_ms: 7.0,
912                },
913                connect: None,
914                tls: None,
915            },
916            status_codes: HashMap::from([(200, 9), (500, 1)]),
917            errors: vec!["server error".into()],
918            gates: Vec::new(),
919            passed_gates: true,
920        }
921    }
922
923    #[tokio::test]
924    async fn execute_single_rejects_invalid_method_token() {
925        let client = reqwest::Client::new();
926        let result = execute_single(
927            &client,
928            "BAD METHOD",
929            "http://127.0.0.1:1",
930            &HashMap::new(),
931            None,
932            None,
933            None,
934        )
935        .await;
936
937        assert!(!result.success);
938        assert!(result.error.unwrap().contains("Invalid HTTP method"));
939    }
940
941    #[test]
942    fn percentile_empty() {
943        assert_eq!(percentile(&[], 50.0), 0);
944    }
945
946    #[test]
947    fn percentile_single() {
948        assert_eq!(percentile(&[42], 50.0), 42);
949        assert_eq!(percentile(&[42], 99.0), 42);
950    }
951
952    #[test]
953    fn percentile_multiple() {
954        let data: Vec<u64> = (1..=100).collect();
955        // p50 of 1..=100: index round(0.5*99)=50, data[50]=51
956        assert_eq!(percentile(&data, 50.0), 51);
957        assert_eq!(percentile(&data, 95.0), 95);
958        assert_eq!(percentile(&data, 99.0), 99);
959    }
960
961    #[test]
962    fn percentile_small_set() {
963        let data = vec![5, 10, 15, 20, 25];
964        assert_eq!(percentile(&data, 50.0), 15);
965    }
966
967    #[test]
968    fn aggregate_all_success() {
969        let results = vec![
970            request_result(10, 200, true, None),
971            request_result(20, 200, true, None),
972            request_result(30, 200, true, None),
973        ];
974        let agg = aggregate_results("test", "GET", "http://localhost", 3, None, results, 100);
975        assert_eq!(agg.total_requests, 3);
976        assert_eq!(agg.successful, 3);
977        assert_eq!(agg.failed, 0);
978        assert_eq!(agg.error_rate, 0.0);
979        assert_eq!(agg.latency.min_ms, 10);
980        assert_eq!(agg.latency.max_ms, 30);
981        assert_eq!(agg.latency.median_ms, 20);
982        assert_eq!(*agg.status_codes.get(&200).unwrap(), 3);
983    }
984
985    #[test]
986    fn aggregate_mixed_results() {
987        let results = vec![
988            request_result(10, 200, true, None),
989            request_result(5, 500, false, Some("server error")),
990            request_result(0, 0, false, Some("connection refused")),
991        ];
992        let agg = aggregate_results("test", "GET", "http://localhost", 2, None, results, 50);
993        assert_eq!(agg.total_requests, 3);
994        assert_eq!(agg.successful, 1);
995        assert_eq!(agg.failed, 2);
996        assert!(agg.error_rate > 60.0);
997        assert_eq!(agg.errors.len(), 2);
998        // Latency only from successful requests
999        assert_eq!(agg.latency.min_ms, 10);
1000        assert_eq!(agg.latency.max_ms, 10);
1001    }
1002
1003    #[test]
1004    fn aggregate_all_failures() {
1005        let results = vec![request_result(0, 0, false, Some("err"))];
1006        let agg = aggregate_results("test", "GET", "http://localhost", 1, None, results, 10);
1007        assert_eq!(agg.successful, 0);
1008        assert_eq!(agg.failed, 1);
1009        assert_eq!(agg.latency.min_ms, 0);
1010        assert_eq!(agg.latency.mean_ms, 0.0);
1011    }
1012
1013    #[test]
1014    fn aggregate_throughput() {
1015        let results = vec![
1016            request_result(10, 200, true, None),
1017            request_result(10, 200, true, None),
1018        ];
1019        // 2 requests in 100ms = 20 req/s
1020        let agg = aggregate_results("test", "GET", "http://localhost", 2, None, results, 100);
1021        assert_eq!(agg.throughput_rps, 20.0);
1022    }
1023
1024    #[test]
1025    fn aggregate_deduplicates_errors() {
1026        let results = vec![
1027            request_result(0, 0, false, Some("same error")),
1028            request_result(0, 0, false, Some("same error")),
1029            request_result(0, 0, false, Some("different")),
1030        ];
1031        let agg = aggregate_results("test", "GET", "http://localhost", 1, None, results, 10);
1032        assert_eq!(agg.errors.len(), 2);
1033    }
1034
1035    #[test]
1036    fn resolve_step_flat_steps() {
1037        let yaml = r#"
1038name: test
1039steps:
1040  - name: first
1041    request:
1042      method: GET
1043      url: "http://localhost"
1044  - name: second
1045    request:
1046      method: POST
1047      url: "http://localhost"
1048"#;
1049        let tf: crate::model::TestFile = serde_yaml::from_str(yaml).unwrap();
1050        let step = resolve_step(&tf, 0).unwrap();
1051        assert_eq!(step.name, "first");
1052        let step = resolve_step(&tf, 1).unwrap();
1053        assert_eq!(step.name, "second");
1054        assert!(resolve_step(&tf, 5).is_err());
1055    }
1056
1057    #[test]
1058    fn resolve_step_test_groups() {
1059        let yaml = r#"
1060name: test
1061tests:
1062  my_test:
1063    steps:
1064      - name: grouped
1065        request:
1066          method: GET
1067          url: "http://localhost"
1068"#;
1069        let tf: crate::model::TestFile = serde_yaml::from_str(yaml).unwrap();
1070        let step = resolve_step(&tf, 0).unwrap();
1071        assert_eq!(step.name, "grouped");
1072    }
1073
1074    #[test]
1075    fn render_json_output() {
1076        let result = sample_bench_result();
1077        let json = render_json(&result);
1078        let parsed: serde_json::Value = serde_json::from_str(&json).unwrap();
1079        assert_eq!(parsed["total_requests"], 10);
1080        assert_eq!(parsed["latency"]["p95_ms"], 45);
1081        assert_eq!(parsed["timings"]["ttfb"]["p95_ms"], 18);
1082    }
1083
1084    #[test]
1085    fn render_human_output() {
1086        let mut result = sample_bench_result();
1087        result.step_name = "health".into();
1088        result.url = "http://localhost/health".into();
1089        result.total_requests = 100;
1090        result.successful = 100;
1091        result.failed = 0;
1092        result.error_rate = 0.0;
1093        result.total_duration_ms = 1500;
1094        result.throughput_rps = 66.67;
1095        result.status_codes = HashMap::from([(200, 100)]);
1096        result.errors.clear();
1097        let output = render_human(&result);
1098        assert!(output.contains("TARN BENCH"));
1099        assert!(output.contains("100 total"));
1100        assert!(output.contains("66.7 req/s"));
1101        assert!(output.contains("p95"));
1102        assert!(output.contains("p99"));
1103        assert!(output.contains("body-read"));
1104    }
1105
1106    #[test]
1107    fn render_csv_output() {
1108        let csv = render_csv(&sample_bench_result());
1109        assert!(csv.contains("throughput_rps"));
1110        assert!(csv.contains("http://localhost"));
1111    }
1112
1113    #[test]
1114    fn evaluate_gates_reports_failures() {
1115        let mut result = sample_bench_result();
1116        result.throughput_rps = 15.0;
1117        result.error_rate = 12.0;
1118        result.latency.p95_ms = 80;
1119        let gates = evaluate_gates(
1120            &result,
1121            &BenchThresholds {
1122                min_throughput_rps: Some(20.0),
1123                max_error_rate: Some(5.0),
1124                max_p95_ms: Some(50),
1125                max_p99_ms: None,
1126            },
1127        );
1128        assert_eq!(gates.len(), 3);
1129        assert!(gates.iter().all(|gate| !gate.passed));
1130    }
1131}