1use bytes::Bytes;
33#[cfg(feature = "tls-impersonate")]
34use std::time::Duration;
35use thiserror::Error;
36use wafrift_transport::stealth::ImpersonateProfile;
37#[cfg(feature = "tls-impersonate")]
38use wafrift_transport::stealth::StealthClient;
39
40#[derive(Debug, Error)]
43pub enum UpstreamError {
44 #[error("upstream request failed: {0}")]
45 Request(String),
46
47 #[error("invalid HTTP method: {0}")]
48 InvalidMethod(String),
49
50 #[error("upstream response too large (cap {cap}): truncated at {got} bytes")]
51 BodyTooLarge { got: usize, cap: usize },
52
53 #[error(
54 "stealth mode requires the `tls-impersonate` cargo feature; \
55 rebuild wafrift-proxy with `cargo build --features \
56 wafrift-transport/tls-impersonate`"
57 )]
58 StealthFeatureDisabled,
59}
60
61#[derive(Debug)]
63pub struct UpstreamResponse {
64 pub status: http::StatusCode,
65 pub headers: http::HeaderMap,
66 pub body: Bytes,
67}
68
69#[derive(Clone)]
73pub enum UpstreamClient {
74 Reqwest(reqwest::Client),
77
78 #[cfg(feature = "tls-impersonate")]
82 Stealth(std::sync::Arc<StealthClient>),
83
84 #[cfg(feature = "tls-impersonate")]
90 StealthPool {
91 clients: std::sync::Arc<Vec<std::sync::Arc<StealthClient>>>,
94 cursor: std::sync::Arc<std::sync::atomic::AtomicUsize>,
98 },
99}
100
101impl std::fmt::Debug for UpstreamClient {
102 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
103 match self {
104 Self::Reqwest(_) => f.debug_tuple("Reqwest").finish(),
105 #[cfg(feature = "tls-impersonate")]
106 Self::Stealth(_) => f.debug_tuple("Stealth").finish(),
107 #[cfg(feature = "tls-impersonate")]
108 Self::StealthPool { clients, cursor } => f
109 .debug_struct("StealthPool")
110 .field("clients", &clients.len())
111 .field("cursor", cursor)
112 .finish(),
113 }
114 }
115}
116
117impl UpstreamClient {
118 #[must_use]
122 pub fn from_reqwest(client: reqwest::Client) -> Self {
123 Self::Reqwest(client)
124 }
125
126 pub fn stealth(_profile: ImpersonateProfile) -> Result<Self, UpstreamError> {
133 #[cfg(feature = "tls-impersonate")]
134 {
135 let client = StealthClient::with_timeout(
141 _profile,
142 Duration::from_secs(wafrift_types::DEFAULT_REQUEST_TIMEOUT_SECS),
143 )
144 .map_err(|e| UpstreamError::Request(e.to_string()))?;
145 Ok(Self::Stealth(std::sync::Arc::new(client)))
146 }
147 #[cfg(not(feature = "tls-impersonate"))]
148 {
149 Err(UpstreamError::StealthFeatureDisabled)
150 }
151 }
152
153 pub fn stealth_pool(_profiles: &[ImpersonateProfile]) -> Result<Self, UpstreamError> {
164 #[cfg(feature = "tls-impersonate")]
165 {
166 if _profiles.is_empty() {
167 return Err(UpstreamError::Request(
168 "stealth_pool requires at least one profile".into(),
169 ));
170 }
171 let mut clients = Vec::with_capacity(_profiles.len());
172 for &p in _profiles {
173 let c = StealthClient::with_timeout(
174 p,
175 Duration::from_secs(wafrift_types::DEFAULT_REQUEST_TIMEOUT_SECS),
176 )
177 .map_err(|e| UpstreamError::Request(format!("{}: {e}", p.name())))?;
178 clients.push(std::sync::Arc::new(c));
179 }
180 Ok(Self::StealthPool {
181 clients: std::sync::Arc::new(clients),
182 cursor: std::sync::Arc::new(std::sync::atomic::AtomicUsize::new(0)),
183 })
184 }
185 #[cfg(not(feature = "tls-impersonate"))]
186 {
187 Err(UpstreamError::StealthFeatureDisabled)
188 }
189 }
190
191 pub async fn send(
203 &self,
204 method: &str,
205 url: &str,
206 headers: &[(String, String)],
207 body: Option<Vec<u8>>,
208 max_body: usize,
209 ) -> Result<UpstreamResponse, UpstreamError> {
210 match self {
211 Self::Reqwest(client) => {
212 let m = reqwest::Method::from_bytes(method.as_bytes())
213 .map_err(|_| UpstreamError::InvalidMethod(method.to_string()))?;
214 let mut req = client.request(m, url);
215 for (k, v) in headers {
216 req = req.header(k.as_str(), v.as_str());
217 }
218 if let Some(b) = body {
219 req = req.body(b);
220 }
221 let resp = req
222 .send()
223 .await
224 .map_err(|e| UpstreamError::Request(e.to_string()))?;
225 let status = http::StatusCode::from_u16(resp.status().as_u16())
226 .map_err(|e| UpstreamError::Request(e.to_string()))?;
227 let headers = resp.headers().clone();
230 let mut buf = Vec::new();
240 let mut stream = resp.bytes_stream();
241 use futures_util::StreamExt;
242 while let Some(chunk) = stream.next().await {
243 let chunk = chunk.map_err(|e| UpstreamError::Request(e.to_string()))?;
244 if buf.len().saturating_add(chunk.len()) > max_body {
245 return Err(UpstreamError::BodyTooLarge {
249 got: buf.len().saturating_add(chunk.len()),
250 cap: max_body,
251 });
252 }
253 buf.extend_from_slice(&chunk);
254 }
255 Ok(UpstreamResponse {
256 status,
257 headers,
258 body: Bytes::from(buf),
259 })
260 }
261 #[cfg(feature = "tls-impersonate")]
262 Self::Stealth(client) => {
263 Self::send_via_stealth(client, method, url, headers, body, max_body).await
264 }
265 #[cfg(feature = "tls-impersonate")]
266 Self::StealthPool { clients, cursor } => {
267 let idx = cursor.fetch_add(1, std::sync::atomic::Ordering::Relaxed) % clients.len();
268 let client = clients[idx].clone();
269 Self::send_via_stealth(&client, method, url, headers, body, max_body).await
270 }
271 }
272 }
273
274 #[cfg(feature = "tls-impersonate")]
275 async fn send_via_stealth(
276 client: &StealthClient,
277 method: &str,
278 url: &str,
279 headers: &[(String, String)],
280 body: Option<Vec<u8>>,
281 max_body: usize,
282 ) -> Result<UpstreamResponse, UpstreamError> {
283 let stealth_resp = client
284 .send(method, url, headers, body.as_deref(), max_body)
285 .await
286 .map_err(|e| UpstreamError::Request(e.to_string()))?;
287 let status = http::StatusCode::from_u16(stealth_resp.status)
288 .map_err(|e| UpstreamError::Request(e.to_string()))?;
289 let mut header_map = http::HeaderMap::with_capacity(stealth_resp.headers.len());
290 for (k, v) in &stealth_resp.headers {
291 if let (Ok(name), Ok(val)) = (
292 http::HeaderName::from_bytes(k.as_bytes()),
293 http::HeaderValue::from_bytes(v.as_bytes()),
294 ) {
295 header_map.append(name, val);
296 }
297 }
298 Ok(UpstreamResponse {
299 status,
300 headers: header_map,
301 body: Bytes::from(stealth_resp.body),
302 })
303 }
304
305 #[must_use]
308 pub fn tls_stack_name(&self) -> &'static str {
309 match self {
310 Self::Reqwest(_) => "rustls (default)",
311 #[cfg(feature = "tls-impersonate")]
312 Self::Stealth(_) => "boringssl (stealth)",
313 #[cfg(feature = "tls-impersonate")]
314 Self::StealthPool { .. } => "boringssl (stealth pool, rotating)",
315 }
316 }
317}
318
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn from_reqwest_wraps_client() {
        let client = reqwest::Client::new();
        let upstream = UpstreamClient::from_reqwest(client);
        assert_eq!(upstream.tls_stack_name(), "rustls (default)");
    }

    #[test]
    fn upstream_error_messages_are_actionable() {
        let err = UpstreamError::InvalidMethod("FUBAR".into());
        assert!(err.to_string().contains("FUBAR"));

        let err = UpstreamError::BodyTooLarge {
            got: 5_000_000,
            cap: 1_000_000,
        };
        let msg = err.to_string();
        assert!(msg.contains("5000000"));
        assert!(msg.contains("1000000"));

        let err = UpstreamError::StealthFeatureDisabled;
        let msg = err.to_string();
        assert!(
            msg.contains("tls-impersonate") && msg.contains("cargo build"),
            "feature-disabled error must name the cargo flag, got: {msg}"
        );
    }

    #[cfg(not(feature = "tls-impersonate"))]
    #[test]
    fn stealth_constructor_errors_when_feature_off() {
        match UpstreamClient::stealth(ImpersonateProfile::Chrome131) {
            Err(UpstreamError::StealthFeatureDisabled) => {}
            Err(other) => panic!("expected StealthFeatureDisabled, got {other}"),
            Ok(_) => panic!("expected error, got Ok variant"),
        }
    }

    #[cfg(feature = "tls-impersonate")]
    #[test]
    fn stealth_constructor_builds_when_feature_on() {
        let upstream = UpstreamClient::stealth(ImpersonateProfile::Chrome131).unwrap();
        assert_eq!(upstream.tls_stack_name(), "boringssl (stealth)");
    }

    #[cfg(feature = "tls-impersonate")]
    #[test]
    fn stealth_pool_rotates_round_robin() {
        let pool = UpstreamClient::stealth_pool(&[
            ImpersonateProfile::Chrome131,
            ImpersonateProfile::Firefox133,
            ImpersonateProfile::Safari18,
        ])
        .unwrap();
        assert_eq!(pool.tls_stack_name(), "boringssl (stealth pool, rotating)");
        if let UpstreamClient::StealthPool { clients, cursor } = &pool {
            assert_eq!(clients.len(), 3);
            // NOTE(review): this replays the cursor arithmetic rather than
            // issuing requests through `send` (which needs a live upstream).
            let next = || cursor.fetch_add(1, std::sync::atomic::Ordering::Relaxed) % clients.len();
            let first = next();
            let second = next();
            let third = next();
            let fourth = next();
            assert_eq!((first, second, third, fourth), (0, 1, 2, 0));
        } else {
            panic!("expected StealthPool variant");
        }
    }

    #[cfg(feature = "tls-impersonate")]
    #[test]
    fn stealth_pool_rejects_empty_profiles() {
        let err = UpstreamClient::stealth_pool(&[]).unwrap_err();
        match err {
            UpstreamError::Request(msg) => assert!(msg.contains("at least one")),
            other => panic!("expected Request error, got {other:?}"),
        }
    }

    #[cfg(not(feature = "tls-impersonate"))]
    #[test]
    fn stealth_pool_errors_when_feature_off() {
        match UpstreamClient::stealth_pool(&[ImpersonateProfile::Chrome131]) {
            Err(UpstreamError::StealthFeatureDisabled) => {}
            Err(other) => panic!("expected StealthFeatureDisabled, got {other}"),
            Ok(_) => panic!("expected error, got Ok variant"),
        }
    }

    #[test]
    fn body_too_large_error_got_and_cap_correct() {
        let err = UpstreamError::BodyTooLarge {
            got: 1024,
            cap: 512,
        };
        match &err {
            UpstreamError::BodyTooLarge { got, cap } => {
                assert_eq!(*got, 1024);
                assert_eq!(*cap, 512);
            }
            other => panic!("unexpected variant: {other:?}"),
        }
        let msg = err.to_string();
        assert!(msg.contains("1024"), "error message must contain got=1024");
        assert!(msg.contains("512"), "error message must contain cap=512");
    }

    #[test]
    fn body_too_large_at_cap_does_not_error() {
        // Documents the boundary contract: the cap is inclusive, so the body
        // is rejected only when buffered + incoming bytes strictly exceed it.
        // NOTE(review): this mirrors the comparison in `send` instead of
        // driving it; keep the two in sync.
        let cap = 100usize;
        let buf_len = 95usize;
        let chunk_len = 5usize;
        assert_eq!(
            buf_len + chunk_len,
            cap,
            "buf+chunk exactly equals cap — must not trigger BodyTooLarge"
        );
        assert!(
            buf_len.saturating_add(chunk_len) <= cap,
            "exactly-at-cap must stay within the budget"
        );
        let over = buf_len + chunk_len + 1;
        assert!(over > cap, "over must exceed cap to trigger error");
    }

    #[test]
    fn upstream_error_invalid_method_contains_method_name() {
        let err = UpstreamError::InvalidMethod("BADMETHOD".into());
        let msg = err.to_string();
        assert!(
            msg.contains("BADMETHOD"),
            "InvalidMethod must name the method, got: {msg}"
        );
    }

    #[test]
    fn upstream_error_request_contains_inner() {
        let err = UpstreamError::Request("connection refused".into());
        let msg = err.to_string();
        assert!(
            msg.contains("connection refused"),
            "Request error must include inner message, got: {msg}"
        );
    }

    #[test]
    fn stealth_feature_disabled_error_names_cargo_flag() {
        let err = UpstreamError::StealthFeatureDisabled;
        let msg = err.to_string();
        assert!(
            msg.contains("tls-impersonate"),
            "feature-disabled error must name `tls-impersonate`, got: {msg}"
        );
        assert!(
            msg.contains("cargo build"),
            "feature-disabled error must mention `cargo build`, got: {msg}"
        );
    }

    #[test]
    fn tls_stack_name_reqwest_variant() {
        let client = reqwest::Client::new();
        let upstream = UpstreamClient::from_reqwest(client);
        assert_eq!(upstream.tls_stack_name(), "rustls (default)");
    }

    #[test]
    fn body_too_large_boundary_just_at_cap() {
        let cap = 1024usize;
        let at_cap = UpstreamError::BodyTooLarge { got: cap, cap };
        let msg = at_cap.to_string();
        // Previously this only formatted the error; now it checks the
        // message actually reports the size.
        assert!(
            msg.contains("1024"),
            "BodyTooLarge must report the byte counts, got: {msg}"
        );
    }

    #[test]
    fn body_too_large_various_sizes() {
        let cases = [
            (1, 0),
            (100, 50),
            (1_000_000, 999_999),
            (usize::MAX, usize::MAX - 1),
        ];
        for (got, cap) in cases {
            let msg = UpstreamError::BodyTooLarge { got, cap }.to_string();
            assert!(
                !msg.is_empty(),
                "error message must not be empty for got={got} cap={cap}"
            );
            assert!(
                msg.contains(&got.to_string()) && msg.contains(&cap.to_string()),
                "message must include got={got} and cap={cap}, got: {msg}"
            );
        }
    }
}