Skip to main content

wafrift_proxy/
upstream.rs

1//! Unified upstream HTTP client for `wafrift-proxy`.
2//!
3//! Wraps either `reqwest::Client` (default, rustls TLS) or
4//! `wafrift_transport::StealthClient` (opt-in via the
5//! `tls-impersonate` feature on `wafrift-transport`, `BoringSSL` via
6//! `wreq` for browser-identical JA3) behind a single `send()` API.
7//!
8//! Both paths return the same [`UpstreamResponse`] shape, so the proxy
9//! call sites at `forward_wafrift_request` and `forward_passthrough`
10//! don't need to know which TLS stack they're talking through.
11//!
12//! # Why this is a wrapper, not a swap
13//!
14//! `reqwest::Client` carries a lot of proxy-specific configuration we
15//! depend on: SSRF-safe DNS resolver (custom bogon-checking resolver
16//! re-runs the policy at connection time), redirect policy, cookie
17//! jar, proxy pool, MITM cert handling. None of that is part of the
18//! "stealth" goal — JA3 parity only matters for the upstream TLS
19//! handshake bytes. So `reqwest` stays as the default; stealth gets
20//! plumbed alongside as an alternative *for the upstream-fetch step
21//! only*, when the practitioner has explicitly opted in.
22//!
23//! # Build matrix
24//!
//! - Default build: `UpstreamClient::Reqwest` is the only enabled
//!   variant. The `Stealth` and `StealthPool` variants are
//!   `#[cfg(feature = "tls-impersonate")]`-gated. Practitioners trying
//!   `--tls-impersonate <profile>` against a binary built without the
//!   feature get an actionable error pointing at the cargo flag.
//! - With `tls-impersonate`: all three variants compile.
31
32use bytes::Bytes;
33#[cfg(feature = "tls-impersonate")]
34use std::time::Duration;
35use thiserror::Error;
36use wafrift_transport::stealth::ImpersonateProfile;
37#[cfg(feature = "tls-impersonate")]
38use wafrift_transport::stealth::StealthClient;
39
40/// Upstream-fetch error. Wraps either reqwest's transport error or a
41/// stealth-client error.
42#[derive(Debug, Error)]
43pub enum UpstreamError {
44    #[error("upstream request failed: {0}")]
45    Request(String),
46
47    #[error("invalid HTTP method: {0}")]
48    InvalidMethod(String),
49
50    #[error("upstream response too large (cap {cap}): truncated at {got} bytes")]
51    BodyTooLarge { got: usize, cap: usize },
52
53    #[error(
54        "stealth mode requires the `tls-impersonate` cargo feature; \
55         rebuild wafrift-proxy with `cargo build --features \
56         wafrift-transport/tls-impersonate`"
57    )]
58    StealthFeatureDisabled,
59}
60
61/// One upstream response, materialised into a uniform shape.
62#[derive(Debug)]
63pub struct UpstreamResponse {
64    pub status: http::StatusCode,
65    pub headers: http::HeaderMap,
66    pub body: Bytes,
67}
68
69/// Either the default reqwest client or a stealth (wreq) client,
70/// optionally wearing a different browser fingerprint per request via
71/// the `UpstreamClient::StealthPool` variant.
72#[derive(Clone)]
73pub enum UpstreamClient {
74    /// Default rustls-backed reqwest client. Carries SSRF resolver,
75    /// redirect policy, proxy-pool, etc.
76    Reqwest(reqwest::Client),
77
78    /// Opt-in BoringSSL-backed stealth client. Used only for the
79    /// upstream forward step when `--tls-impersonate <profile>` is
80    /// set on the proxy command line. Compiled out by default.
81    #[cfg(feature = "tls-impersonate")]
82    Stealth(std::sync::Arc<StealthClient>),
83
84    /// Round-robin pool of stealth clients (one per profile). Lets the
85    /// proxy rotate browser fingerprints per request, which defeats
86    /// rate-limit-by-JA3 and per-fingerprint reputation systems
87    /// (Cloudflare bot-management, Akamai BMP). Selected with
88    /// `--tls-impersonate-rotate chrome131,firefox133,safari18`.
89    #[cfg(feature = "tls-impersonate")]
90    StealthPool {
91        /// Pre-built clients, one per profile. Indexed via the atomic
92        /// `cursor` below.
93        clients: std::sync::Arc<Vec<std::sync::Arc<StealthClient>>>,
94        /// Round-robin counter. `AtomicUsize` so `send()` stays `&self`
95        /// — the proxy holds the pool inside an `Arc` and dispatches
96        /// from many concurrent tasks.
97        cursor: std::sync::Arc<std::sync::atomic::AtomicUsize>,
98    },
99}
100
101impl std::fmt::Debug for UpstreamClient {
102    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
103        match self {
104            Self::Reqwest(_) => f.debug_tuple("Reqwest").finish(),
105            #[cfg(feature = "tls-impersonate")]
106            Self::Stealth(_) => f.debug_tuple("Stealth").finish(),
107            #[cfg(feature = "tls-impersonate")]
108            Self::StealthPool { clients, cursor } => f
109                .debug_struct("StealthPool")
110                .field("clients", &clients.len())
111                .field("cursor", cursor)
112                .finish(),
113        }
114    }
115}
116
117impl UpstreamClient {
118    /// Build the default reqwest variant from a pre-configured client.
119    /// All the SSRF/resolver/cookie wiring stays where it already is in
120    /// `main()`; this is just the wrapping ctor.
121    #[must_use]
122    pub fn from_reqwest(client: reqwest::Client) -> Self {
123        Self::Reqwest(client)
124    }
125
126    /// Build a stealth variant wearing the given browser profile.
127    ///
128    /// # Errors
129    ///
130    /// Returns [`UpstreamError::StealthFeatureDisabled`] if the binary
131    /// was built without `tls-impersonate`.
132    pub fn stealth(_profile: ImpersonateProfile) -> Result<Self, UpstreamError> {
133        #[cfg(feature = "tls-impersonate")]
134        {
135            // R56 pass-21: was hardcoded 60s — diverged from the canonical
136            // `wafrift_types::DEFAULT_REQUEST_TIMEOUT_SECS` (30) used by every
137            // other client in this binary. Operators saw different timeout
138            // behaviour depending on whether the request went through
139            // UpstreamClient::stealth vs main.rs's direct StealthClient.
140            let client = StealthClient::with_timeout(
141                _profile,
142                Duration::from_secs(wafrift_types::DEFAULT_REQUEST_TIMEOUT_SECS),
143            )
144            .map_err(|e| UpstreamError::Request(e.to_string()))?;
145            Ok(Self::Stealth(std::sync::Arc::new(client)))
146        }
147        #[cfg(not(feature = "tls-impersonate"))]
148        {
149            Err(UpstreamError::StealthFeatureDisabled)
150        }
151    }
152
153    /// Build a rotating pool of stealth clients (one per profile).
154    /// `send()` advances a round-robin cursor so successive requests
155    /// land on different fingerprints.
156    ///
157    /// # Errors
158    ///
159    /// - [`UpstreamError::StealthFeatureDisabled`] if built without
160    ///   `tls-impersonate`.
161    /// - [`UpstreamError::Request`] if any client fails to build OR if
162    ///   `_profiles` is empty (a pool of zero is meaningless).
163    pub fn stealth_pool(_profiles: &[ImpersonateProfile]) -> Result<Self, UpstreamError> {
164        #[cfg(feature = "tls-impersonate")]
165        {
166            if _profiles.is_empty() {
167                return Err(UpstreamError::Request(
168                    "stealth_pool requires at least one profile".into(),
169                ));
170            }
171            let mut clients = Vec::with_capacity(_profiles.len());
172            for &p in _profiles {
173                let c = StealthClient::with_timeout(
174                    p,
175                    Duration::from_secs(wafrift_types::DEFAULT_REQUEST_TIMEOUT_SECS),
176                )
177                .map_err(|e| UpstreamError::Request(format!("{}: {e}", p.name())))?;
178                clients.push(std::sync::Arc::new(c));
179            }
180            Ok(Self::StealthPool {
181                clients: std::sync::Arc::new(clients),
182                cursor: std::sync::Arc::new(std::sync::atomic::AtomicUsize::new(0)),
183            })
184        }
185        #[cfg(not(feature = "tls-impersonate"))]
186        {
187            Err(UpstreamError::StealthFeatureDisabled)
188        }
189    }
190
191    /// Send a request and read the response (with body bounded by
192    /// `max_body`). Method/URL/headers/body shape mirrors what
193    /// `forward_wafrift_request` already builds — the migration is just
194    /// "stop calling `client.request(method, url).send()` directly,
195    /// call this instead".
196    ///
197    /// # Body bounding
198    ///
199    /// Bodies are truncated at `max_body` bytes (no error, the
200    /// truncated content is still useful for WAF-block detection —
201    /// matches `forward_wafrift_request`'s existing semantics).
202    pub async fn send(
203        &self,
204        method: &str,
205        url: &str,
206        headers: &[(String, String)],
207        body: Option<Vec<u8>>,
208        max_body: usize,
209    ) -> Result<UpstreamResponse, UpstreamError> {
210        match self {
211            Self::Reqwest(client) => {
212                let m = reqwest::Method::from_bytes(method.as_bytes())
213                    .map_err(|_| UpstreamError::InvalidMethod(method.to_string()))?;
214                let mut req = client.request(m, url);
215                for (k, v) in headers {
216                    req = req.header(k.as_str(), v.as_str());
217                }
218                if let Some(b) = body {
219                    req = req.body(b);
220                }
221                let resp = req
222                    .send()
223                    .await
224                    .map_err(|e| UpstreamError::Request(e.to_string()))?;
225                let status = http::StatusCode::from_u16(resp.status().as_u16())
226                    .map_err(|e| UpstreamError::Request(e.to_string()))?;
227                // reqwest's HeaderMap is from http already (re-export), so
228                // we can clone it directly.
229                let headers = resp.headers().clone();
230                // Bound body read.
231                // Previously the loop silently truncated at max_body bytes and
232                // returned Ok(truncated_body) — UpstreamError::BodyTooLarge was
233                // defined but never emitted, making it impossible for callers to
234                // distinguish a complete response from a truncated one. A WAF
235                // block that happens to produce a large body would be silently
236                // truncated, potentially making the response look like a pass
237                // when the real indicator was cut off. Now we return the explicit
238                // error so callers can treat it as a failed scan / skip.
239                let mut buf = Vec::new();
240                let mut stream = resp.bytes_stream();
241                use futures_util::StreamExt;
242                while let Some(chunk) = stream.next().await {
243                    let chunk = chunk.map_err(|e| UpstreamError::Request(e.to_string()))?;
244                    if buf.len().saturating_add(chunk.len()) > max_body {
245                        // We've hit the cap. Return a hard error so callers know
246                        // the body was not fully read rather than silently lying
247                        // about the truncated content.
248                        return Err(UpstreamError::BodyTooLarge {
249                            got: buf.len().saturating_add(chunk.len()),
250                            cap: max_body,
251                        });
252                    }
253                    buf.extend_from_slice(&chunk);
254                }
255                Ok(UpstreamResponse {
256                    status,
257                    headers,
258                    body: Bytes::from(buf),
259                })
260            }
261            #[cfg(feature = "tls-impersonate")]
262            Self::Stealth(client) => {
263                Self::send_via_stealth(client, method, url, headers, body, max_body).await
264            }
265            #[cfg(feature = "tls-impersonate")]
266            Self::StealthPool { clients, cursor } => {
267                let idx = cursor.fetch_add(1, std::sync::atomic::Ordering::Relaxed) % clients.len();
268                let client = clients[idx].clone();
269                Self::send_via_stealth(&client, method, url, headers, body, max_body).await
270            }
271        }
272    }
273
274    #[cfg(feature = "tls-impersonate")]
275    async fn send_via_stealth(
276        client: &StealthClient,
277        method: &str,
278        url: &str,
279        headers: &[(String, String)],
280        body: Option<Vec<u8>>,
281        max_body: usize,
282    ) -> Result<UpstreamResponse, UpstreamError> {
283        let stealth_resp = client
284            .send(method, url, headers, body.as_deref(), max_body)
285            .await
286            .map_err(|e| UpstreamError::Request(e.to_string()))?;
287        let status = http::StatusCode::from_u16(stealth_resp.status)
288            .map_err(|e| UpstreamError::Request(e.to_string()))?;
289        let mut header_map = http::HeaderMap::with_capacity(stealth_resp.headers.len());
290        for (k, v) in &stealth_resp.headers {
291            if let (Ok(name), Ok(val)) = (
292                http::HeaderName::from_bytes(k.as_bytes()),
293                http::HeaderValue::from_bytes(v.as_bytes()),
294            ) {
295                header_map.append(name, val);
296            }
297        }
298        Ok(UpstreamResponse {
299            status,
300            headers: header_map,
301            body: Bytes::from(stealth_resp.body),
302        })
303    }
304
305    /// Returns the operator-visible name of the active TLS stack, for
306    /// log lines / `/_wafrift/status` output.
307    #[must_use]
308    pub fn tls_stack_name(&self) -> &'static str {
309        match self {
310            Self::Reqwest(_) => "rustls (default)",
311            #[cfg(feature = "tls-impersonate")]
312            Self::Stealth(_) => "boringssl (stealth)",
313            #[cfg(feature = "tls-impersonate")]
314            Self::StealthPool { .. } => "boringssl (stealth pool, rotating)",
315        }
316    }
317}
318
319#[cfg(test)]
320mod tests {
321    use super::*;
322
323    #[test]
324    fn from_reqwest_wraps_client() {
325        let client = reqwest::Client::new();
326        let upstream = UpstreamClient::from_reqwest(client);
327        assert_eq!(upstream.tls_stack_name(), "rustls (default)");
328    }
329
330    #[test]
331    fn upstream_error_messages_are_actionable() {
332        let err = UpstreamError::InvalidMethod("FUBAR".into());
333        assert!(err.to_string().contains("FUBAR"));
334
335        let err = UpstreamError::BodyTooLarge {
336            got: 5_000_000,
337            cap: 1_000_000,
338        };
339        let msg = err.to_string();
340        assert!(msg.contains("5000000"));
341        assert!(msg.contains("1000000"));
342
343        let err = UpstreamError::StealthFeatureDisabled;
344        let msg = err.to_string();
345        assert!(
346            msg.contains("tls-impersonate") && msg.contains("cargo build"),
347            "feature-disabled error must name the cargo flag, got: {msg}"
348        );
349    }
350
351    #[cfg(not(feature = "tls-impersonate"))]
352    #[test]
353    fn stealth_constructor_errors_when_feature_off() {
354        match UpstreamClient::stealth(ImpersonateProfile::Chrome131) {
355            Err(UpstreamError::StealthFeatureDisabled) => {}
356            Err(other) => panic!("expected StealthFeatureDisabled, got {other}"),
357            Ok(_) => panic!("expected error, got Ok variant"),
358        }
359    }
360
361    #[cfg(feature = "tls-impersonate")]
362    #[test]
363    fn stealth_constructor_builds_when_feature_on() {
364        let upstream = UpstreamClient::stealth(ImpersonateProfile::Chrome131).unwrap();
365        assert_eq!(upstream.tls_stack_name(), "boringssl (stealth)");
366    }
367
368    #[cfg(feature = "tls-impersonate")]
369    #[test]
370    fn stealth_pool_rotates_round_robin() {
371        let pool = UpstreamClient::stealth_pool(&[
372            ImpersonateProfile::Chrome131,
373            ImpersonateProfile::Firefox133,
374            ImpersonateProfile::Safari18,
375        ])
376        .unwrap();
377        assert_eq!(pool.tls_stack_name(), "boringssl (stealth pool, rotating)");
378        // Cursor advances on every send. We can't test a real send
379        // without a network, but we can exercise the cursor by faking
380        // index calculation.
381        if let UpstreamClient::StealthPool { clients, cursor } = &pool {
382            assert_eq!(clients.len(), 3);
383            let first = cursor.fetch_add(1, std::sync::atomic::Ordering::Relaxed) % clients.len();
384            let second = cursor.fetch_add(1, std::sync::atomic::Ordering::Relaxed) % clients.len();
385            let third = cursor.fetch_add(1, std::sync::atomic::Ordering::Relaxed) % clients.len();
386            let fourth = cursor.fetch_add(1, std::sync::atomic::Ordering::Relaxed) % clients.len();
387            assert_eq!((first, second, third, fourth), (0, 1, 2, 0));
388        } else {
389            panic!("expected StealthPool variant");
390        }
391    }
392
393    #[cfg(feature = "tls-impersonate")]
394    #[test]
395    fn stealth_pool_rejects_empty_profiles() {
396        let err = UpstreamClient::stealth_pool(&[]).unwrap_err();
397        match err {
398            UpstreamError::Request(msg) => assert!(msg.contains("at least one")),
399            other => panic!("expected Request error, got {other:?}"),
400        }
401    }
402
403    #[cfg(not(feature = "tls-impersonate"))]
404    #[test]
405    fn stealth_pool_errors_when_feature_off() {
406        match UpstreamClient::stealth_pool(&[ImpersonateProfile::Chrome131]) {
407            Err(UpstreamError::StealthFeatureDisabled) => {}
408            Err(other) => panic!("expected StealthFeatureDisabled, got {other}"),
409            Ok(_) => panic!("expected error, got Ok variant"),
410        }
411    }
412
413    // ── New tests added 2026-05-24 ─────────────────────────────────────────
414
415    #[test]
416    fn body_too_large_error_got_and_cap_correct() {
417        // UpstreamError::BodyTooLarge must carry correct got and cap values.
418        let err = UpstreamError::BodyTooLarge {
419            got: 1024,
420            cap: 512,
421        };
422        match &err {
423            UpstreamError::BodyTooLarge { got, cap } => {
424                assert_eq!(*got, 1024);
425                assert_eq!(*cap, 512);
426            }
427            other => panic!("unexpected variant: {other:?}"),
428        }
429        let msg = err.to_string();
430        assert!(msg.contains("1024"), "error message must contain got=1024");
431        assert!(msg.contains("512"), "error message must contain cap=512");
432    }
433
434    #[test]
435    fn body_too_large_at_cap_does_not_error() {
436        // We can't call .send() without network, but we can verify that
437        // the error is only BodyTooLarge { got, cap } where got > cap.
438        // At exactly cap bytes, the logic should NOT produce BodyTooLarge.
439        // Verify the check logic directly: buf.len() + chunk.len() > max_body.
440        let cap = 100usize;
441        let buf_len = 95usize;
442        let chunk_len = 5usize;
443        assert_eq!(
444            buf_len + chunk_len,
445            cap,
446            "buf+chunk exactly equals cap — must not trigger BodyTooLarge"
447        );
448        // Would trigger:
449        let over = buf_len + chunk_len + 1;
450        assert!(over > cap, "over must exceed cap to trigger error");
451    }
452
453    #[test]
454    fn upstream_error_invalid_method_contains_method_name() {
455        let err = UpstreamError::InvalidMethod("BADMETHOD".into());
456        let msg = err.to_string();
457        assert!(
458            msg.contains("BADMETHOD"),
459            "InvalidMethod must name the method, got: {msg}"
460        );
461    }
462
463    #[test]
464    fn upstream_error_request_contains_inner() {
465        let err = UpstreamError::Request("connection refused".into());
466        let msg = err.to_string();
467        assert!(
468            msg.contains("connection refused"),
469            "Request error must include inner message, got: {msg}"
470        );
471    }
472
473    #[test]
474    fn stealth_feature_disabled_error_names_cargo_flag() {
475        // The error message must name both the feature flag AND the
476        // cargo build command so the user knows what to do.
477        let err = UpstreamError::StealthFeatureDisabled;
478        let msg = err.to_string();
479        assert!(
480            msg.contains("tls-impersonate"),
481            "feature-disabled error must name `tls-impersonate`, got: {msg}"
482        );
483        assert!(
484            msg.contains("cargo build"),
485            "feature-disabled error must mention `cargo build`, got: {msg}"
486        );
487    }
488
489    #[test]
490    fn tls_stack_name_reqwest_variant() {
491        let client = reqwest::Client::new();
492        let upstream = UpstreamClient::from_reqwest(client);
493        assert_eq!(upstream.tls_stack_name(), "rustls (default)");
494    }
495
496    #[test]
497    fn body_too_large_boundary_just_at_cap() {
498        // Boundary: got == cap is NOT a BodyTooLarge condition in the contract.
499        // got > cap is the trigger. This test verifies the condition semantics.
500        let cap = 1024usize;
501        let at_cap = UpstreamError::BodyTooLarge { got: cap, cap };
502        // Just to exercise the Display — must not panic.
503        let _ = at_cap.to_string();
504    }
505
506    #[test]
507    fn body_too_large_various_sizes() {
508        // Spot-check several (got, cap) pairs.
509        let cases = [
510            (1, 0),
511            (100, 50),
512            (1_000_000, 999_999),
513            (usize::MAX, usize::MAX - 1),
514        ];
515        for (got, cap) in cases {
516            let err = UpstreamError::BodyTooLarge { got, cap };
517            let msg = err.to_string();
518            assert!(
519                !msg.is_empty(),
520                "error message must not be empty for got={got} cap={cap}"
521            );
522        }
523    }
524}