1use base64::Engine;
11use bytes::Bytes;
12use http::{Method, Uri};
13use serde::Serialize;
14use std::collections::HashMap;
15use std::sync::Arc;
16use std::time::Duration;
17use tokio::sync::RwLock;
18use tokio::time::timeout as tokio_timeout;
19use url::Url;
20
21use crate::cookie::CookieJar;
22use crate::error::{Error, Result};
23use crate::fingerprint::{http2::Http2Settings, FingerprintProfile};
24use crate::headers::Headers;
25use crate::pool::alt_svc::AltSvcCache;
26use crate::pool::multiplexer::{ConnectionPool, PoolKey};
27use crate::request::{Body, IntoUrl, RedirectPolicy, Request};
28use crate::response::Response;
29use crate::timeouts::Timeouts;
30use crate::transport::connector::{BoringConnector, MaybeHttpsStream};
31use crate::transport::h1::H1Connection;
32use crate::transport::h2::{H2Connection, H2PooledConnection, H2Tunnel, PseudoHeaderOrder};
33use crate::transport::h3::{H3Client, H3Tunnel};
34use crate::version::HttpVersion;
35use crate::websocket::{WebSocketBuilder, WebSocketClientParts};
36
/// HTTP client handle holding the connectors, connection pools and default
/// request settings used when executing requests.
///
/// Cloning copies the plain settings and clones the `Arc`-wrapped fields
/// (`alt_svc_cache`, `h2_pool`, `h1_pool`, `cookie_store`), so clones share
/// those values.
#[derive(Clone)]
pub struct Client {
    /// Connector returned by `connector_for_uri` when no certificate relaxation applies.
    connector: BoringConnector,
    /// Connector returned by `connector_for_uri` when `danger_accept_invalid_certs`
    /// is set, or for localhost hosts when `localhost_allows_invalid_certs` is set.
    /// NOTE(review): presumably built without certificate verification — its
    /// construction is not visible here.
    insecure_connector: BoringConnector,
    /// HTTP/3 client used for HTTP/3 requests and RFC 9220 WebSocket tunnels.
    h3_client: H3Client,
    /// Alt-Svc cache: filled from `alt-svc` response headers and consulted
    /// before attempting an HTTP/3 upgrade.
    alt_svc_cache: Arc<AltSvcCache>,
    /// Pooled HTTP/2 connections keyed by `PoolKey`.
    h2_pool: Arc<RwLock<HashMap<PoolKey, H2PooledConnection>>>,
    /// Pool of reusable HTTP/1.1 streams (`get_h1` / `put_h1`).
    h1_pool: Arc<ConnectionPool>,
    /// Settings passed to `H2Connection::connect` for every new HTTP/2 connection.
    http2_settings: Http2Settings,
    /// Pseudo-header order passed to `H2Connection::connect`.
    pseudo_order: PseudoHeaderOrder,
    /// Version used when a request does not specify one.
    default_version: HttpVersion,
    /// Default timeouts; a per-request timeout overrides `total`.
    timeouts: Timeouts,
    /// Gates both storing `alt-svc` headers and attempting HTTP/3 upgrades.
    h3_upgrade_enabled: bool,
    /// Allows HTTP/2 on connections whose ALPN did not report `h2`; also
    /// required for `ws://` RFC 8441 tunnels.
    http2_prior_knowledge: bool,
    /// When set, `connector_for_uri` always returns `insecure_connector`.
    danger_accept_invalid_certs: bool,
    /// When set, localhost targets use `insecure_connector`.
    localhost_allows_invalid_certs: bool,
    /// Headers copied into every new request / WebSocket builder.
    default_headers: Headers,
    /// Redirect handling applied by `execute`.
    redirect_policy: RedirectPolicy,
    /// Optional shared cookie jar used to add `Cookie` headers and store
    /// cookies from responses.
    cookie_store: Option<Arc<RwLock<CookieJar>>>,
}
76
/// Builder for a single HTTP request bound to a borrowed [`Client`].
pub struct RequestBuilder<'a> {
    /// Client the request will be executed with.
    client: &'a Client,
    /// Parsed target URL; `None` when URL conversion failed in `new`.
    url: Option<Url>,
    /// HTTP method of the request.
    method: Method,
    /// Request headers, initialised from the client's default headers.
    headers: Headers,
    /// Request body; starts as `Body::Empty`.
    body: Body,
    /// Optional per-request HTTP version override.
    version: Option<HttpVersion>,
    /// Optional per-request timeout; `execute` copies it into `Timeouts::total`.
    timeout: Option<Duration>,
    /// First error recorded while building; returned by `build`.
    error: Option<Error>,
}
88
/// Builder for a WebSocket tunnel over HTTP/2 (RFC 8441, per the error
/// messages emitted by `open`).
pub struct WebSocketH2Builder<'a> {
    /// Client whose connectors, HTTP/2 pool and settings are used.
    client: &'a Client,
    /// Parsed WebSocket URL; `None` when URL conversion failed in `new`.
    url: Option<Url>,
    /// Headers sent with the tunnel request, initialised from the client's defaults.
    headers: Headers,
    /// URL conversion error captured in `new`; returned by `open`.
    error: Option<Error>,
}
96
/// Builder for a WebSocket tunnel over HTTP/3 (RFC 9220, per the error
/// message emitted by `open`).
pub struct WebSocketH3Builder<'a> {
    /// Client whose HTTP/3 client and settings are used.
    client: &'a Client,
    /// Parsed WebSocket URL; `None` when URL conversion failed in `new`.
    url: Option<Url>,
    /// Headers sent with the tunnel request, initialised from the client's defaults.
    headers: Headers,
    /// URL conversion error captured in `new`; returned by `open`.
    error: Option<Error>,
}
104
/// Collects configuration for a [`Client`].
///
/// NOTE(review): `build` is outside this view, so how `fingerprint`,
/// `prefer_http2`, `root_certs` and `use_platform_roots` are consumed cannot
/// be confirmed from here.
pub struct ClientBuilder {
    /// Fingerprint profile; defaults to `FingerprintProfile::default()`.
    fingerprint: FingerprintProfile,
    /// Explicit HTTP/2 settings; `None` until `http2_settings` is called.
    http2_settings: Option<Http2Settings>,
    /// HTTP/2 pseudo-header order; defaults to `PseudoHeaderOrder::Chrome`.
    pseudo_order: PseudoHeaderOrder,
    /// Timeout configuration; defaults to `Timeouts::default()`.
    timeouts: Timeouts,
    /// Defaults to `true`.
    prefer_http2: bool,
    /// Defaults to `true`.
    h3_upgrade_enabled: bool,
    /// Defaults to `false`.
    http2_prior_knowledge: bool,
    /// Additional root certificates as raw bytes — encoding not visible here.
    root_certs: Vec<Vec<u8>>,
    /// Defaults to `false`.
    use_platform_roots: bool,
    /// Defaults to `false`.
    danger_accept_invalid_certs: bool,
    /// Defaults to `true`.
    localhost_allows_invalid_certs: bool,
    /// Headers applied to every request built from the resulting client.
    default_headers: Headers,
    /// Redirect policy; defaults to `RedirectPolicy::None`.
    redirect_policy: RedirectPolicy,
    /// Optional cookie jar; `None` by default.
    cookie_store: Option<Arc<RwLock<CookieJar>>>,
}
128
129impl Client {
130 pub fn new() -> Result<Self> {
132 ClientBuilder::new().build()
133 }
134
    /// Returns a fresh [`ClientBuilder`] for configuring a client.
    pub fn builder() -> ClientBuilder {
        ClientBuilder::new()
    }
139
140 pub fn get(&self, url: impl IntoUrl) -> RequestBuilder<'_> {
142 RequestBuilder::new(self, Method::GET, url)
143 }
144
145 pub fn post(&self, url: impl IntoUrl) -> RequestBuilder<'_> {
147 RequestBuilder::new(self, Method::POST, url)
148 }
149
150 pub fn put(&self, url: impl IntoUrl) -> RequestBuilder<'_> {
152 RequestBuilder::new(self, Method::PUT, url)
153 }
154
155 pub fn delete(&self, url: impl IntoUrl) -> RequestBuilder<'_> {
157 RequestBuilder::new(self, Method::DELETE, url)
158 }
159
160 pub fn head(&self, url: impl IntoUrl) -> RequestBuilder<'_> {
162 RequestBuilder::new(self, Method::HEAD, url)
163 }
164
165 pub fn patch(&self, url: impl IntoUrl) -> RequestBuilder<'_> {
167 RequestBuilder::new(self, Method::PATCH, url)
168 }
169
    /// Starts building a request with an arbitrary `method` for `url`.
    ///
    /// A URL conversion failure is stored in the builder and surfaces when
    /// the request is built or sent.
    pub fn request(&self, method: Method, url: impl IntoUrl) -> RequestBuilder<'_> {
        RequestBuilder::new(self, method, url)
    }
174
    /// Starts building a WebSocket tunnel over HTTP/2 for `url`.
    pub fn websocket_h2(&self, url: impl IntoUrl) -> WebSocketH2Builder<'_> {
        WebSocketH2Builder::new(self, url)
    }
179
    /// Starts building a WebSocket tunnel over HTTP/3 for `url`.
    pub fn websocket_h3(&self, url: impl IntoUrl) -> WebSocketH3Builder<'_> {
        WebSocketH3Builder::new(self, url)
    }
184
185 pub fn websocket(&self, url: impl IntoUrl) -> WebSocketBuilder<'_> {
187 Client::websocket_with_parts(
188 WebSocketClientParts {
189 connector: &self.connector,
190 insecure_connector: &self.insecure_connector,
191 default_headers: &self.default_headers,
192 timeouts: &self.timeouts,
193 cookie_store: self.cookie_store.as_ref(),
194 danger_accept_invalid_certs: self.danger_accept_invalid_certs,
195 localhost_allows_invalid_certs: self.localhost_allows_invalid_certs,
196 },
197 url,
198 )
199 }
200
    /// Returns the shared Alt-Svc cache used by this client.
    pub fn alt_svc_cache(&self) -> &Arc<AltSvcCache> {
        &self.alt_svc_cache
    }
205
206 fn is_localhost(host: &str) -> bool {
208 host == "localhost" || host == "127.0.0.1" || host == "::1"
209 }
210
211 fn connector_for_uri(&self, uri: &Uri) -> &BoringConnector {
213 if self.danger_accept_invalid_certs {
215 return &self.insecure_connector;
216 }
217
218 if self.localhost_allows_invalid_certs {
220 if let Some(host) = uri.host() {
221 if Self::is_localhost(host) {
222 return &self.insecure_connector;
223 }
224 }
225 }
226
227 &self.connector
228 }
229}
230
231impl<'a> WebSocketH2Builder<'a> {
232 fn new(client: &'a Client, url: impl IntoUrl) -> Self {
233 let mut error = None;
234 let url = match url.into_url() {
235 Ok(url) => Some(url),
236 Err(err) => {
237 error = Some(err);
238 None
239 }
240 };
241
242 Self {
243 client,
244 url,
245 headers: client.default_headers.clone(),
246 error,
247 }
248 }
249
    /// Sets a single header on the tunnel request via `Headers::insert`.
    pub fn header(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
        self.headers.insert(key, value);
        self
    }
255
    /// Replaces the whole header set, including the defaults copied from the client.
    pub fn headers(mut self, headers: impl Into<Headers>) -> Self {
        self.headers = headers.into();
        self
    }
261
    /// Opens a WebSocket tunnel over HTTP/2 and returns it.
    ///
    /// A pooled HTTP/2 connection for the target is tried first; if none
    /// exists, or opening the tunnel on it fails, a new connection is
    /// established, stored in the pool and used instead.
    ///
    /// # Errors
    /// - the URL conversion error captured when the builder was created;
    /// - `Error::WebSocketUnsupported` for schemes other than `ws`/`wss`, for
    ///   `ws://` without HTTP/2 prior knowledge, or when the new connection
    ///   did not end up on HTTP/2;
    /// - `Error::HttpProtocol` when the rewritten URL is not a valid URI;
    /// - any error from connecting or from opening the tunnel.
    ///
    /// NOTE(review): nothing in this function applies `client.timeouts`,
    /// unlike the HTTP/3 builder's `open`, which enforces `timeouts.total`.
    pub async fn open(self) -> Result<H2Tunnel> {
        if let Some(err) = self.error {
            return Err(err);
        }

        let url = self.url.ok_or_else(|| Error::missing("websocket URL"))?;

        // Map the WebSocket scheme onto the HTTP scheme used for the connection.
        let websocket_scheme = url.scheme();
        let h2_scheme = match websocket_scheme {
            "wss" => "https",
            "ws" => {
                // Cleartext has no ALPN, so HTTP/2 must have been requested explicitly.
                if !self.client.http2_prior_knowledge {
                    return Err(Error::WebSocketUnsupported(
                        "ws:// RFC 8441 requires explicit HTTP/2 prior knowledge".into(),
                    ));
                }
                "http"
            }
            other => {
                return Err(Error::WebSocketUnsupported(format!(
                    "RFC 8441 requires ws:// or wss:// URL, got {other}"
                )));
            }
        };

        let mut h2_url = url.clone();
        h2_url
            .set_scheme(h2_scheme)
            .map_err(|_| Error::WebSocketUnsupported("invalid WebSocket URL scheme".into()))?;

        let uri: Uri = h2_url
            .as_str()
            .parse()
            .map_err(|e| Error::HttpProtocol(format!("Invalid URI: {}", e)))?;

        let headers = self.headers.to_vec();
        let pool_key = Client::make_pool_key(&uri);

        // The read guard lives only inside the inner block, so it is released
        // before the tunnel is opened or the write lock is taken below.
        if let Some(conn) = {
            let pool = self.client.h2_pool.read().await;
            pool.get(&pool_key).cloned()
        } {
            match conn
                .open_websocket_tunnel(uri.clone(), headers.clone())
                .await
            {
                Ok(tunnel) => return Ok(tunnel),
                Err(err) => {
                    // Evict the failing connection and fall through to a fresh connect.
                    tracing::debug!("Pooled RFC 8441 tunnel open failed, reconnecting: {}", err);
                    let mut pool = self.client.h2_pool.write().await;
                    pool.remove(&pool_key);
                }
            }
        }

        let connector = self.client.connector_for_uri(&uri);
        let stream = connector.connect(&uri).await?;

        // HTTP/2 is usable either by prior knowledge (ws://) or when the TLS
        // stream negotiated `h2` via ALPN.
        let use_http2 = if websocket_scheme == "ws" && self.client.http2_prior_knowledge {
            true
        } else if let MaybeHttpsStream::Https(ref ssl_stream) = stream {
            ssl_stream.ssl().selected_alpn_protocol() == Some(b"h2")
        } else {
            false
        };

        if !use_http2 {
            return Err(Error::WebSocketUnsupported(
                "RFC 8441 WebSocket requires ALPN h2 or explicit HTTP/2 prior knowledge".into(),
            ));
        }

        let h2_conn = H2Connection::connect(
            stream,
            self.client.http2_settings.clone(),
            self.client.pseudo_order,
        )
        .await?;
        let pooled_conn = H2PooledConnection::new(h2_conn);

        // Publish the connection before opening the tunnel; the write guard is
        // scoped so it is dropped before the await below.
        {
            let mut pool = self.client.h2_pool.write().await;
            pool.insert(pool_key, pooled_conn.clone());
        }

        pooled_conn.open_websocket_tunnel(uri, headers).await
    }
350}
351
352impl<'a> WebSocketH3Builder<'a> {
353 fn new(client: &'a Client, url: impl IntoUrl) -> Self {
354 let mut error = None;
355 let url = match url.into_url() {
356 Ok(url) => Some(url),
357 Err(err) => {
358 error = Some(err);
359 None
360 }
361 };
362
363 Self {
364 client,
365 url,
366 headers: client.default_headers.clone(),
367 error,
368 }
369 }
370
    /// Sets a single header on the tunnel request via `Headers::insert`.
    pub fn header(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
        self.headers.insert(key, value);
        self
    }
376
    /// Replaces the whole header set, including the defaults copied from the client.
    pub fn headers(mut self, headers: impl Into<Headers>) -> Self {
        self.headers = headers.into();
        self
    }
382
383 pub async fn open(self) -> Result<H3Tunnel> {
385 if let Some(err) = self.error {
386 return Err(err);
387 }
388
389 let url = self.url.ok_or_else(|| Error::missing("websocket URL"))?;
390 if url.scheme() != "wss" {
391 return Err(Error::WebSocketUnsupported(
392 "RFC 9220 WebSocket over HTTP/3 requires wss://".into(),
393 ));
394 }
395
396 let mut h3_url = url.clone();
397 h3_url
398 .set_scheme("https")
399 .map_err(|_| Error::WebSocketUnsupported("invalid WebSocket URL scheme".into()))?;
400
401 let mut h3_client = self.client.h3_client.clone();
402 if self.client.danger_accept_invalid_certs
403 || (self.client.localhost_allows_invalid_certs
404 && h3_url
405 .host_str()
406 .is_some_and(|host| Client::is_localhost(host)))
407 {
408 h3_client = h3_client.danger_accept_invalid_certs(true);
409 }
410
411 let fut = h3_client.open_websocket_tunnel(h3_url.as_str(), self.headers.to_vec());
412 if let Some(total_timeout) = self.client.timeouts.total {
413 tokio_timeout(total_timeout, fut)
414 .await
415 .map_err(|_| Error::TotalTimeout(total_timeout))?
416 } else {
417 fut.await
418 }
419 }
420}
421
422impl<'a> RequestBuilder<'a> {
423 fn new(client: &'a Client, method: Method, url: impl IntoUrl) -> Self {
424 let mut error = None;
425 let url = match url.into_url() {
426 Ok(url) => Some(url),
427 Err(err) => {
428 error = Some(err);
429 None
430 }
431 };
432
433 Self {
434 client,
435 url,
436 method,
437 headers: client.default_headers.clone(),
438 body: Body::Empty,
439 version: None,
440 timeout: None,
441 error,
442 }
443 }
444
445 fn set_error(&mut self, error: Error) {
446 if self.error.is_none() {
447 self.error = Some(error);
448 }
449 }
450
451 fn ensure_content_type(&mut self, value: &str) {
452 if !self.headers.contains("content-type") {
453 self.headers.insert("Content-Type", value.to_string());
454 }
455 }
456
    /// Sets a single request header via `Headers::insert`.
    pub fn header(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
        self.headers.insert(key, value);
        self
    }
462
    /// Adds a request header via `Headers::append`.
    pub fn header_append(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
        self.headers.append(key, value);
        self
    }
468
    /// Replaces the whole header set, including the defaults copied from the client.
    pub fn headers(mut self, headers: impl Into<Headers>) -> Self {
        self.headers = headers.into();
        self
    }
474
    /// Sets the request body. No `Content-Type` header is added here.
    pub fn body(mut self, body: impl Into<Body>) -> Self {
        self.body = body.into();
        self
    }
480
481 pub fn query<T: Serialize + ?Sized>(mut self, query: &T) -> Self {
483 if self.error.is_some() {
484 return self;
485 }
486
487 let url = match self.url.as_mut() {
488 Some(url) => url,
489 None => return self,
490 };
491
492 match serde_urlencoded::to_string(query) {
493 Ok(encoded) => {
494 if !encoded.is_empty() {
495 let merged = match url.query() {
496 Some(existing) if !existing.is_empty() => {
497 format!("{}&{}", existing, encoded)
498 }
499 _ => encoded,
500 };
501 url.set_query(Some(&merged));
502 }
503 }
504 Err(err) => self.set_error(err.into()),
505 }
506
507 self
508 }
509
510 pub fn json<T: Serialize + ?Sized>(mut self, json: &T) -> Self {
512 if self.error.is_some() {
513 return self;
514 }
515
516 match serde_json::to_vec(json) {
517 Ok(bytes) => {
518 self.body = Body::Json(bytes);
519 self.ensure_content_type("application/json");
520 }
521 Err(err) => self.set_error(err.into()),
522 }
523
524 self
525 }
526
527 pub fn form<T: Serialize + ?Sized>(mut self, form: &T) -> Self {
529 if self.error.is_some() {
530 return self;
531 }
532
533 match serde_urlencoded::to_string(form) {
534 Ok(encoded) => {
535 self.body = Body::Form(encoded);
536 self.ensure_content_type("application/x-www-form-urlencoded");
537 }
538 Err(err) => self.set_error(err.into()),
539 }
540
541 self
542 }
543
544 pub fn bearer_auth(mut self, token: impl AsRef<str>) -> Self {
546 self.headers
547 .insert("Authorization", format!("Bearer {}", token.as_ref()));
548 self
549 }
550
551 pub fn basic_auth<P: AsRef<str>>(
553 mut self,
554 username: impl AsRef<str>,
555 password: Option<P>,
556 ) -> Self {
557 let creds = match password {
558 Some(p) => format!("{}:{}", username.as_ref(), p.as_ref()),
559 None => format!("{}:", username.as_ref()),
560 };
561 let encoded = base64::engine::general_purpose::STANDARD.encode(creds.as_bytes());
562 self.headers
563 .insert("Authorization", format!("Basic {}", encoded));
564 self
565 }
566
    /// Sets a per-request timeout; `Client::execute` copies it into
    /// `Timeouts::total` for this request.
    pub fn timeout(mut self, timeout: Duration) -> Self {
        self.timeout = Some(timeout);
        self
    }
572
    /// Overrides the HTTP version for this request instead of using the
    /// client's default version.
    pub fn version(mut self, version: HttpVersion) -> Self {
        self.version = Some(version);
        self
    }
578
579 pub fn build(self) -> Result<Request> {
581 if let Some(error) = self.error {
582 return Err(error);
583 }
584
585 let url = self.url.ok_or_else(|| Error::missing("url"))?;
586
587 Ok(Request {
588 method: self.method,
589 url,
590 headers: self.headers,
591 body: self.body,
592 version: self.version,
593 timeout: self.timeout,
594 })
595 }
596
597 pub async fn send(self) -> Result<Response> {
599 let client = self.client.clone();
600 let request = self.build()?;
601 client.execute(request).await
602 }
603
604 pub async fn send_streaming(
608 self,
609 ) -> Result<(
610 Response,
611 tokio::sync::mpsc::Receiver<std::result::Result<Bytes, crate::transport::h2::H2Error>>,
612 )> {
613 let client = self.client.clone();
614 let request = self.build()?;
615 let mut timeouts = client.timeouts.clone();
616 if let Some(total) = request.timeout {
617 timeouts.total = Some(total);
618 }
619 let mut headers = request.headers.clone();
620
621 if let Some(jar) = &client.cookie_store {
622 if !headers.contains("cookie") {
623 if let Some(cookie_header) =
624 jar.read().await.build_cookie_header(request.url.as_str())
625 {
626 headers.insert("Cookie", cookie_header);
627 }
628 }
629 }
630
631 let version = request.version.unwrap_or(client.default_version);
632
633 if !matches!(version, HttpVersion::Http2 | HttpVersion::Auto) {
635 return Err(Error::HttpProtocol(
636 "Streaming only supported for HTTP/2".into(),
637 ));
638 }
639
640 let uri: Uri = request
642 .url
643 .as_str()
644 .parse()
645 .map_err(|e| Error::HttpProtocol(format!("Invalid URI: {}", e)))?;
646
647 let connector = client.connector_for_uri(&uri);
650 let connect_fut = connector.connect(&uri);
651 let stream = if let Some(connect_timeout) = timeouts.connect {
652 tokio_timeout(connect_timeout, connect_fut)
653 .await
654 .map_err(|_| Error::ConnectTimeout(connect_timeout))??
655 } else {
656 connect_fut.await?
657 };
658
659 let alpn = stream.alpn_protocol();
661 if !alpn.is_h2() {
662 return Err(Error::HttpProtocol(format!(
663 "Expected h2 ALPN, got {:?}",
664 alpn
665 )));
666 }
667
668 let h2_connect_fut =
670 H2Connection::connect(stream, client.http2_settings.clone(), client.pseudo_order);
671 let mut h2_conn = if let Some(connect_timeout) = timeouts.connect {
672 tokio_timeout(connect_timeout, h2_connect_fut)
673 .await
674 .map_err(|_| Error::ConnectTimeout(connect_timeout))??
675 } else {
676 h2_connect_fut.await?
677 };
678
679 let mut path = request.url.path().to_string();
681 if path.is_empty() {
682 path = "/".to_string();
683 }
684 if let Some(query) = request.url.query() {
685 path.push('?');
686 path.push_str(query);
687 }
688
689 let host = request.url.host_str().unwrap_or("localhost");
690 let authority = if let Some(port) = request.url.port_or_known_default() {
691 if port == 443 {
692 host.to_string()
693 } else {
694 format!("{}:{}", host, port)
695 }
696 } else {
697 host.to_string()
698 };
699
700 let full_uri = format!("https://{}{}", authority, path);
702 let mut request_builder = http::Request::builder()
703 .method(request.method.clone())
704 .uri(&full_uri);
705
706 for (key, value) in headers.iter() {
708 request_builder = request_builder.header(key, value);
709 }
710
711 let body = request.body.clone().into_bytes()?;
712 let http_request = request_builder
713 .body(body)
714 .map_err(|e| Error::HttpProtocol(format!("Failed to build request: {}", e)))?;
715
716 let send_fut = h2_conn.send_request_streaming(http_request);
718 let (response, rx) = if let Some(ttfb_timeout) = timeouts.ttfb {
719 tokio_timeout(ttfb_timeout, send_fut)
720 .await
721 .map_err(|_| Error::TtfbTimeout(ttfb_timeout))??
722 } else {
723 send_fut.await?
724 };
725
726 tokio::spawn(async move {
728 loop {
729 match h2_conn.read_streaming_frames().await {
730 Ok(true) => continue,
731 Ok(false) => break,
732 Err(e) => {
733 tracing::debug!("Streaming read error: {}", e);
734 break;
735 }
736 }
737 }
738 });
739
740 let status = response.status().as_u16();
742 let headers = response
743 .headers()
744 .iter()
745 .map(|(k, v)| (k.as_str().to_string(), v.to_str().unwrap_or("").to_string()))
746 .collect::<Vec<(String, String)>>();
747
748 let our_response = crate::response::Response::new(
749 status,
750 Headers::from(headers),
751 Bytes::new(), "HTTP/2".to_string(),
753 );
754
755 let request_url = request.url.clone();
756 let our_response = our_response.with_url(request_url.clone());
757
758 if let Some(jar) = &client.cookie_store {
759 jar.write()
760 .await
761 .store_from_headers(our_response.headers(), request_url.as_str());
762 }
763
764 Ok((our_response, rx))
765 }
766}
767
768impl Client {
769 pub async fn execute(&self, mut request: Request) -> Result<Response> {
771 let policy = self.redirect_policy.clone();
772 let mut redirects = 0u32;
773
774 loop {
775 let mut headers = request.headers.clone();
776 let cookie_injected = self.apply_cookie_header(&request, &mut headers).await;
777 request.headers = headers;
778
779 let mut timeouts = self.timeouts.clone();
780 if let Some(total) = request.timeout {
781 timeouts.total = Some(total);
782 }
783
784 let response = self.execute_once(&request, &timeouts).await?;
785
786 self.store_cookies(&response, &request.url).await;
787
788 if matches!(policy, RedirectPolicy::None) || !response.is_redirect() {
789 return Ok(response);
790 }
791
792 let location = match response.redirect_url() {
793 Some(value) => value,
794 None => return Ok(response),
795 };
796
797 if let RedirectPolicy::Limited(limit) = policy {
798 if redirects >= limit {
799 return Err(Error::RedirectLimit { count: limit });
800 }
801 }
802
803 let next_url = request.url.join(location).map_err(Error::from)?;
804 let mut next_request = self.redirect_request(&request, &response, next_url);
805
806 if cookie_injected {
807 next_request.headers.remove("cookie");
808 }
809
810 request = next_request;
811 redirects += 1;
812 }
813 }
814
815 async fn execute_once(&self, request: &Request, timeouts: &Timeouts) -> Result<Response> {
816 let version = request.version.unwrap_or(self.default_version);
817
818 if matches!(version, HttpVersion::Http3Only) {
820 return self
821 .send_h3_for_url(request, request.url.clone(), timeouts)
822 .await;
823 }
824
825 if matches!(version, HttpVersion::Http3) {
827 match self
828 .send_h3_for_url(request, request.url.clone(), timeouts)
829 .await
830 {
831 Ok(response) => return Ok(response),
832 Err(e) => {
833 tracing::debug!("HTTP/3 failed, falling back to HTTP/1.1 or HTTP/2: {}", e);
834 }
836 }
837 }
838
839 if matches!(version, HttpVersion::Auto) && self.h3_upgrade_enabled {
841 let origin = Self::origin_for_url(&request.url);
842 if let Some(alt_svc) = self.alt_svc_cache.get_h3_alternative(&origin).await {
843 tracing::debug!(
844 "Alt-Svc indicates HTTP/3 support for {}, attempting upgrade",
845 origin
846 );
847
848 let mut h3_url = request.url.clone();
849 let _ = h3_url.set_scheme("https");
850 if let Some(ref host) = alt_svc.host {
851 h3_url
852 .set_host(Some(host))
853 .map_err(|_| Error::HttpProtocol("Invalid Alt-Svc host".into()))?;
854 }
855 let _ = h3_url.set_port(Some(alt_svc.port));
856
857 match self
858 .send_h3_for_url(request, h3_url.clone(), timeouts)
859 .await
860 {
861 Ok(response) => return Ok(response.with_url(h3_url)),
862 Err(e) => {
863 tracing::debug!("HTTP/3 upgrade failed, using HTTP/1.1 or HTTP/2: {}", e);
864 }
866 }
867 }
868 }
869
870 self.send_h1_h2(request, version, timeouts).await
872 }
873
874 async fn send_h3_for_url(
875 &self,
876 request: &Request,
877 url: Url,
878 timeouts: &Timeouts,
879 ) -> Result<Response> {
880 let body = if request.body.is_empty() {
881 None
882 } else {
883 Some(request.body.clone().into_bytes()?.to_vec())
884 };
885
886 let fut = self.h3_client.send_request(
887 url.as_str(),
888 request.method.as_str(),
889 request.headers.to_vec(),
890 body,
891 );
892
893 let response = if let Some(total_timeout) = timeouts.total {
895 tokio_timeout(total_timeout, fut)
896 .await
897 .map_err(|_| Error::TotalTimeout(total_timeout))??
898 } else {
899 fut.await?
900 };
901
902 Ok(response.with_url(url))
903 }
904
    /// Sends `request` over HTTP/2 or HTTP/1.1 and returns the response
    /// tagged with the request URL.
    ///
    /// Flow:
    /// 1. When HTTP/2 is preferred, try a pooled HTTP/2 connection; on
    ///    failure evict it and open a new connection. If that connection is
    ///    usable for HTTP/2 it is pooled and used.
    /// 2. Otherwise take a pooled HTTP/1.1 stream or open a new one. If the
    ///    TLS stream negotiated `h2`, switch to HTTP/2 and pool it.
    /// 3. Otherwise send over HTTP/1.1; a failure on a pooled stream is
    ///    retried once on a fresh connection.
    ///
    /// When HTTP/3 upgrades are enabled, an `alt-svc` response header is
    /// stored in the Alt-Svc cache for the request's origin.
    ///
    /// # Errors
    /// - `Error::HttpProtocol` for an invalid URI or when `version` is
    ///   `Http3` / `Http3Only`;
    /// - `Error::ConnectTimeout` / `Error::TtfbTimeout` when those elapse;
    /// - any connect, handshake or request error.
    ///
    /// NOTE(review): `timeouts.total` is not read in this function, and the
    /// pooled HTTP/2 attempt runs without the TTFB timeout.
    async fn send_h1_h2(
        &self,
        request: &Request,
        version: HttpVersion,
        timeouts: &Timeouts,
    ) -> Result<Response> {
        let request_url = request.url.clone();

        let uri: Uri = request
            .url
            .as_str()
            .parse()
            .map_err(|e| Error::HttpProtocol(format!("Invalid URI: {}", e)))?;

        // `Auto` prefers HTTP/2 only when the client default is HTTP/2.
        let prefer_http2 = match version {
            HttpVersion::Http1_1 => false,
            HttpVersion::Http2 => true,
            HttpVersion::Http3 | HttpVersion::Http3Only => {
                return Err(Error::HttpProtocol("HTTP/3 should use send_h3".into()));
            }
            HttpVersion::Auto => matches!(self.default_version, HttpVersion::Http2),
        };

        let h3_upgrade_enabled = self.h3_upgrade_enabled;
        let alt_svc_cache = self.alt_svc_cache.clone();
        let origin = Self::origin_for_url(&request.url);

        // Materialise headers and body once; each attempt below clones them.
        let headers_vec = request.headers.to_vec();
        let body_bytes = if request.body.is_empty() {
            None
        } else {
            Some(request.body.clone().into_bytes()?)
        };

        if prefer_http2 {
            let pool_key = Self::make_pool_key(&uri);

            // The read guard is released at the end of this block, before any
            // request is sent.
            let pooled = {
                let pool = self.h2_pool.read().await;
                pool.get(&pool_key).cloned()
            };

            if let Some(conn) = pooled {
                let result = conn
                    .send_request(
                        request.method.clone(),
                        &uri,
                        headers_vec.clone(),
                        body_bytes.clone(),
                    )
                    .await;

                match result {
                    Ok(response) => {
                        if h3_upgrade_enabled {
                            if let Some(alt_svc) = response.get_header("alt-svc") {
                                alt_svc_cache.parse_and_store(&origin, alt_svc).await;
                            }
                        }
                        return Ok(response.with_url(request_url));
                    }
                    Err(e) => {
                        // Evict the failing connection and fall through to a new one.
                        tracing::debug!("Pooled HTTP/2 connection failed, creating new: {}", e);
                        let mut pool = self.h2_pool.write().await;
                        pool.remove(&pool_key);
                    }
                }
            }

            let connector = self.connector_for_uri(&uri);
            let connect_fut = connector.connect(&uri);
            let stream = if let Some(connect_timeout) = timeouts.connect {
                tokio_timeout(connect_timeout, connect_fut)
                    .await
                    .map_err(|_| Error::ConnectTimeout(connect_timeout))??
            } else {
                connect_fut.await?
            };

            // HTTP/2 is used when prior knowledge is configured and ALPN did
            // not report h2, or when the TLS stream negotiated h2.
            let use_http2 = if self.http2_prior_knowledge && !stream.alpn_protocol().is_h2() {
                true
            } else if let MaybeHttpsStream::Https(ref ssl_stream) = stream {
                ssl_stream.ssl().selected_alpn_protocol() == Some(b"h2")
            } else {
                false
            };

            if use_http2 {
                let h2_conn =
                    H2Connection::connect(stream, self.http2_settings.clone(), self.pseudo_order)
                        .await?;
                let pooled_conn = H2PooledConnection::new(h2_conn);

                // Publish the connection before sending; the write guard is
                // scoped so it is dropped before the request is awaited.
                {
                    let mut pool = self.h2_pool.write().await;
                    pool.insert(pool_key, pooled_conn.clone());
                }

                let fut = pooled_conn.send_request(
                    request.method.clone(),
                    &uri,
                    headers_vec.clone(),
                    body_bytes.clone(),
                );

                let response = if let Some(ttfb_timeout) = timeouts.ttfb {
                    tokio_timeout(ttfb_timeout, fut)
                        .await
                        .map_err(|_| Error::TtfbTimeout(ttfb_timeout))?
                } else {
                    fut.await
                }?;

                if h3_upgrade_enabled {
                    if let Some(alt_svc) = response.get_header("alt-svc") {
                        alt_svc_cache.parse_and_store(&origin, alt_svc).await;
                    }
                }

                return Ok(response.with_url(request_url));
            }
            // NOTE(review): when HTTP/2 is not usable, `stream` goes out of
            // scope here and the HTTP/1.1 path below obtains its own stream.
        }

        let pool_key = Self::make_pool_key(&uri);

        let mut stream_opt = self.h1_pool.get_h1(&pool_key).await;
        // Tracks whether a failure may be retried on a fresh connection.
        let mut used_pooled = stream_opt.is_some();

        let mut stream = if let Some(pooled_stream) = stream_opt.take() {
            tracing::debug!("H1: Reusing pooled connection for {:?}", pool_key);
            pooled_stream
        } else {
            tracing::debug!("H1: Creating new connection for {:?}", pool_key);
            let connector = self.connector_for_uri(&uri);
            let connect_fut = connector.connect(&uri);
            if let Some(connect_timeout) = timeouts.connect {
                tokio_timeout(connect_timeout, connect_fut)
                    .await
                    .map_err(|_| Error::ConnectTimeout(connect_timeout))??
            } else {
                connect_fut.await?
            }
        };

        // Even without an HTTP/2 preference, honour an ALPN-selected h2.
        let server_wants_h2 = if let MaybeHttpsStream::Https(ref ssl_stream) = stream {
            ssl_stream.ssl().selected_alpn_protocol() == Some(b"h2")
        } else {
            false
        };

        let response = if server_wants_h2 {
            tracing::debug!("Server selected h2 via ALPN, upgrading to HTTP/2");

            let h2_conn =
                H2Connection::connect(stream, self.http2_settings.clone(), self.pseudo_order)
                    .await?;
            let pooled_conn = H2PooledConnection::new(h2_conn);

            {
                let mut pool = self.h2_pool.write().await;
                pool.insert(pool_key, pooled_conn.clone());
            }

            let fut = pooled_conn.send_request(
                request.method.clone(),
                &uri,
                headers_vec.clone(),
                body_bytes.clone(),
            );

            if let Some(ttfb_timeout) = timeouts.ttfb {
                tokio_timeout(ttfb_timeout, fut)
                    .await
                    .map_err(|_| Error::TtfbTimeout(ttfb_timeout))?
            } else {
                fut.await
            }?
        } else {
            // At most two iterations: the optional pooled attempt, then one
            // attempt on a fresh connection.
            let result = loop {
                // The stream is moved into the request; it comes back only on success.
                let stream_for_request = stream;
                let fut = Self::do_send_http1(
                    stream_for_request,
                    request.method.clone(),
                    &uri,
                    headers_vec.clone(),
                    body_bytes.clone(),
                );

                // A TTFB timeout returns from the function immediately; it is
                // not treated as a retryable failure.
                let request_result = if let Some(ttfb_timeout) = timeouts.ttfb {
                    tokio_timeout(ttfb_timeout, fut)
                        .await
                        .map_err(|_| Error::TtfbTimeout(ttfb_timeout))?
                } else {
                    fut.await
                };

                match request_result {
                    Ok((resp, returned_stream)) => {
                        // Hand the stream back to the pool.
                        self.h1_pool.put_h1(pool_key.clone(), returned_stream).await;
                        break Ok(resp);
                    }
                    Err(e) => {
                        if used_pooled {
                            tracing::debug!(
                                "H1: Pooled connection failed for {:?}, creating new: {}",
                                pool_key,
                                e
                            );
                            let connector = self.connector_for_uri(&uri);
                            let connect_fut = connector.connect(&uri);
                            stream = if let Some(connect_timeout) = timeouts.connect {
                                tokio_timeout(connect_timeout, connect_fut)
                                    .await
                                    .map_err(|_| Error::ConnectTimeout(connect_timeout))??
                            } else {
                                connect_fut.await?
                            };
                            // Only one retry: the next failure is final.
                            used_pooled = false;
                            continue;
                        } else {
                            tracing::debug!(
                                "H1: Request failed for {:?}, discarding connection: {}",
                                pool_key,
                                e
                            );
                            break Err(e);
                        }
                    }
                }
            };

            result?
        };

        if h3_upgrade_enabled {
            if let Some(alt_svc) = response.get_header("alt-svc") {
                alt_svc_cache.parse_and_store(&origin, alt_svc).await;
            }
        }

        Ok(response.with_url(request_url))
    }
1183
1184 fn redirect_request(&self, request: &Request, response: &Response, next_url: Url) -> Request {
1185 let status = response.status().as_u16();
1186 let mut method = request.method.clone();
1187 let mut body = request.body.clone();
1188 let mut headers = request.headers.clone();
1189
1190 let should_switch = status == 303
1191 || ((status == 301 || status == 302) && !matches!(method, Method::GET | Method::HEAD));
1192
1193 if should_switch {
1194 method = Method::GET;
1195 body = Body::Empty;
1196 headers.remove("content-length");
1197 headers.remove("content-type");
1198 }
1199
1200 if Self::is_cross_origin(&request.url, &next_url) {
1201 headers.remove("authorization");
1202 }
1203
1204 Request {
1205 method,
1206 url: next_url,
1207 headers,
1208 body,
1209 version: request.version,
1210 timeout: request.timeout,
1211 }
1212 }
1213
1214 async fn apply_cookie_header(&self, request: &Request, headers: &mut Headers) -> bool {
1215 if let Some(jar) = &self.cookie_store {
1216 if !headers.contains("cookie") {
1217 if let Some(cookie_header) =
1218 jar.read().await.build_cookie_header(request.url.as_str())
1219 {
1220 headers.insert("Cookie", cookie_header);
1221 return true;
1222 }
1223 }
1224 }
1225 false
1226 }
1227
1228 async fn store_cookies(&self, response: &Response, url: &Url) {
1229 if let Some(jar) = &self.cookie_store {
1230 jar.write()
1231 .await
1232 .store_from_headers(response.headers(), url.as_str());
1233 }
1234 }
1235
1236 fn make_pool_key(uri: &Uri) -> PoolKey {
1238 let host = uri.host().unwrap_or("localhost").to_string();
1239 let is_https = uri.scheme_str() == Some("https");
1240 let port = uri.port_u16().unwrap_or(if is_https { 443 } else { 80 });
1241 PoolKey::new(host, port, is_https)
1242 }
1243
1244 async fn do_send_http1(
1245 stream: MaybeHttpsStream,
1246 method: Method,
1247 uri: &Uri,
1248 headers: Vec<(String, String)>,
1249 body: Option<Bytes>,
1250 ) -> Result<(Response, MaybeHttpsStream)> {
1251 let mut conn = H1Connection::new(stream);
1252 let response = conn.send_request(method, uri, headers, body).await?;
1253 let stream = conn.into_inner();
1254 Ok((response, stream))
1255 }
1256
1257 fn origin_for_url(url: &Url) -> String {
1259 let scheme = url.scheme();
1260 let host = url.host_str().unwrap_or("localhost");
1261 let port = url
1262 .port_or_known_default()
1263 .unwrap_or(if scheme == "https" { 443 } else { 80 });
1264
1265 if (scheme == "https" && port == 443) || (scheme == "http" && port == 80) {
1266 format!("{}://{}", scheme, host)
1267 } else {
1268 format!("{}://{}:{}", scheme, host, port)
1269 }
1270 }
1271
1272 fn is_cross_origin(a: &Url, b: &Url) -> bool {
1273 a.scheme() != b.scheme()
1274 || a.host_str() != b.host_str()
1275 || a.port_or_known_default() != b.port_or_known_default()
1276 }
1277}
1278
1279impl ClientBuilder {
    /// Creates a builder with the default configuration.
    ///
    /// Notable defaults: Chrome pseudo-header order, HTTP/2 preferred, HTTP/3
    /// upgrades enabled, no HTTP/2 prior knowledge, certificate validation
    /// enforced except for localhost, redirects not followed, no cookie store.
    pub fn new() -> Self {
        Self {
            fingerprint: FingerprintProfile::default(),
            http2_settings: None,
            pseudo_order: PseudoHeaderOrder::Chrome,
            timeouts: Timeouts::default(),
            prefer_http2: true,
            h3_upgrade_enabled: true,
            http2_prior_knowledge: false,
            root_certs: Vec::new(),
            use_platform_roots: false,
            danger_accept_invalid_certs: false,
            // Localhost targets may present invalid certificates by default.
            localhost_allows_invalid_certs: true,
            default_headers: Headers::new(),
            redirect_policy: RedirectPolicy::None,
            cookie_store: None,
        }
    }
1306
1307 pub fn fingerprint(mut self, fingerprint: FingerprintProfile) -> Self {
1309 self.fingerprint = fingerprint;
1310 self
1311 }
1312
1313 pub fn http2_settings(mut self, settings: Http2Settings) -> Self {
1315 self.http2_settings = Some(settings);
1316 self
1317 }
1318
1319 pub fn pseudo_order(mut self, order: PseudoHeaderOrder) -> Self {
1321 self.pseudo_order = order;
1322 self
1323 }
1324
1325 pub fn timeouts(mut self, timeouts: Timeouts) -> Self {
1329 self.timeouts = timeouts;
1330 self
1331 }
1332
1333 pub fn api_timeouts(mut self) -> Self {
1337 self.timeouts = Timeouts::api_defaults();
1338 self
1339 }
1340
1341 pub fn streaming_timeouts(mut self) -> Self {
1346 self.timeouts = Timeouts::streaming_defaults();
1347 self
1348 }
1349
1350 #[deprecated(
1355 since = "1.0.2",
1356 note = "Use `timeouts()` or `total_timeout()` instead"
1357 )]
1358 pub fn timeout(mut self, timeout: Duration) -> Self {
1359 self.timeouts.total = Some(timeout);
1360 self
1361 }
1362
1363 pub fn total_timeout(mut self, timeout: Duration) -> Self {
1365 self.timeouts.total = Some(timeout);
1366 self
1367 }
1368
1369 pub fn connect_timeout(mut self, timeout: Duration) -> Self {
1371 self.timeouts.connect = Some(timeout);
1372 self
1373 }
1374
1375 pub fn ttfb_timeout(mut self, timeout: Duration) -> Self {
1377 self.timeouts.ttfb = Some(timeout);
1378 self
1379 }
1380
1381 pub fn read_timeout(mut self, timeout: Duration) -> Self {
1383 self.timeouts.read_idle = Some(timeout);
1384 self
1385 }
1386
1387 pub fn write_timeout(mut self, timeout: Duration) -> Self {
1389 self.timeouts.write_idle = Some(timeout);
1390 self
1391 }
1392
1393 pub fn pool_acquire_timeout(mut self, timeout: Duration) -> Self {
1395 self.timeouts.pool_acquire = Some(timeout);
1396 self
1397 }
1398
1399 pub fn default_headers(mut self, headers: impl Into<Headers>) -> Self {
1401 self.default_headers = headers.into();
1402 self
1403 }
1404
1405 pub fn default_header(mut self, name: impl Into<String>, value: impl Into<String>) -> Self {
1407 self.default_headers.insert(name, value);
1408 self
1409 }
1410
1411 pub fn user_agent(mut self, value: impl Into<String>) -> Self {
1413 self.default_headers.insert("User-Agent", value.into());
1414 self
1415 }
1416
1417 pub fn redirect_policy(mut self, policy: RedirectPolicy) -> Self {
1419 self.redirect_policy = policy;
1420 self
1421 }
1422
1423 pub fn cookie_store(mut self, enabled: bool) -> Self {
1425 if enabled {
1426 self.cookie_store = Some(Arc::new(RwLock::new(CookieJar::new())));
1427 } else {
1428 self.cookie_store = None;
1429 }
1430 self
1431 }
1432
1433 pub fn cookie_jar(mut self, jar: Arc<RwLock<CookieJar>>) -> Self {
1435 self.cookie_store = Some(jar);
1436 self
1437 }
1438
1439 pub fn prefer_http2(mut self, prefer: bool) -> Self {
1441 self.prefer_http2 = prefer;
1442 self
1443 }
1444
1445 pub fn h3_upgrade(mut self, enabled: bool) -> Self {
1452 self.h3_upgrade_enabled = enabled;
1453 self
1454 }
1455
1456 pub fn http2_prior_knowledge(mut self, enabled: bool) -> Self {
1459 self.http2_prior_knowledge = enabled;
1460 if enabled {
1462 self.prefer_http2 = true;
1463 }
1464 self
1465 }
1466
1467 pub fn add_root_certificate(mut self, cert: Vec<u8>) -> Self {
1469 self.root_certs.push(cert);
1470 self
1471 }
1472
1473 pub fn with_platform_roots(mut self, enabled: bool) -> Self {
1484 self.use_platform_roots = enabled;
1485 self
1486 }
1487
1488 pub fn danger_accept_invalid_certs(mut self, accept: bool) -> Self {
1494 self.danger_accept_invalid_certs = accept;
1495 self
1496 }
1497
1498 pub fn localhost_allows_invalid_certs(mut self, allow: bool) -> Self {
1506 self.localhost_allows_invalid_certs = allow;
1507 self
1508 }
1509
1510 pub fn build(self) -> Result<Client> {
1512 let tls_fingerprint = self.fingerprint.tls_fingerprint();
1514 let mut connector = BoringConnector::with_fingerprint(tls_fingerprint.clone())
1515 .with_root_certificates(self.root_certs.clone())
1516 .with_platform_roots(self.use_platform_roots);
1517
1518 if self.danger_accept_invalid_certs {
1520 connector = connector.danger_accept_invalid_certs(true);
1521 }
1522
1523 let insecure_connector = BoringConnector::with_fingerprint(tls_fingerprint.clone())
1525 .with_root_certificates(self.root_certs)
1526 .with_platform_roots(self.use_platform_roots)
1527 .danger_accept_invalid_certs(true);
1528
1529 let mut h3_client = H3Client::with_fingerprint(tls_fingerprint);
1531 if self.danger_accept_invalid_certs {
1532 h3_client = h3_client.danger_accept_invalid_certs(true);
1533 }
1534
1535 let http2_settings = self.http2_settings.unwrap_or_default();
1537
1538 let default_version = if self.prefer_http2 {
1540 HttpVersion::Http2
1541 } else {
1542 HttpVersion::Http1_1
1543 };
1544
1545 Ok(Client {
1546 connector,
1547 insecure_connector,
1548 h3_client,
1549 alt_svc_cache: Arc::new(AltSvcCache::new()),
1550 h2_pool: Arc::new(RwLock::new(HashMap::new())),
1551 h1_pool: Arc::new(ConnectionPool::new()),
1552 http2_settings,
1553 pseudo_order: self.pseudo_order,
1554 default_version,
1555 timeouts: self.timeouts,
1556 h3_upgrade_enabled: self.h3_upgrade_enabled,
1557 http2_prior_knowledge: self.http2_prior_knowledge,
1558 danger_accept_invalid_certs: self.danger_accept_invalid_certs,
1559 localhost_allows_invalid_certs: self.localhost_allows_invalid_certs,
1560 default_headers: self.default_headers,
1561 redirect_policy: self.redirect_policy,
1562 cookie_store: self.cookie_store,
1563 })
1564 }
1565}
1566
1567impl Default for ClientBuilder {
1568 fn default() -> Self {
1569 Self::new()
1570 }
1571}
1572
1573impl Default for AltSvcCache {
1574 fn default() -> Self {
1575 Self::new()
1576 }
1577}