statsig_rust/networking/providers/
net_provider_reqwest.rs1use std::collections::HashMap;
2use std::time::Duration;
3
4use async_trait::async_trait;
5
6use crate::{
7 log_e, log_w,
8 networking::{
9 http_types::{HttpMethod, RequestArgs, Response},
10 NetworkProvider,
11 },
12};
13
14use crate::networking::proxy_config::ProxyConfig;
15use reqwest::Method;
16
// Tag passed as the first argument to the `log_w!` / `log_e!` calls in this file.
const TAG: &str = "NetworkProviderReqwest";
18
/// Fieldless [`NetworkProvider`] implementation that issues requests through `reqwest`.
pub struct NetworkProviderReqwest {}
20
21#[async_trait]
22impl NetworkProvider for NetworkProviderReqwest {
23 async fn send(&self, method: &HttpMethod, args: &RequestArgs) -> Response {
24 if let Some(is_shutdown) = &args.is_shutdown {
25 if is_shutdown.load(std::sync::atomic::Ordering::SeqCst) {
26 return Response {
27 status_code: 0,
28 data: None,
29 error: Some("Request was shutdown".to_string()),
30 headers: None,
31 };
32 }
33 }
34
35 let request = self.build_request(method, args);
36
37 let error;
38 let mut status_code = 0;
39 let mut data = None;
40 let mut headers = None;
41
42 match request.send().await {
43 Ok(response) => {
44 status_code = response.status().as_u16();
45 headers = get_response_headers(&response);
46 data = response.bytes().await.ok().map(|bytes| bytes.to_vec());
47 error = None;
48 }
49 Err(e) => {
50 let error_message = get_error_message(e);
51 log_w!(TAG, "Request Error: {} {}", &args.url, error_message);
52 error = Some(error_message);
53 }
54 }
55
56 Response {
57 status_code,
58 data,
59 error,
60 headers,
61 }
62 }
63}
64
65impl NetworkProviderReqwest {
66 fn build_request(
67 &self,
68 method: &HttpMethod,
69 request_args: &RequestArgs,
70 ) -> reqwest::RequestBuilder {
71 let method_actual = match method {
72 HttpMethod::GET => Method::GET,
73 HttpMethod::POST => Method::POST,
74 };
75 let is_post = method_actual == Method::POST;
76
77 let mut client_builder = reqwest::Client::builder();
78
79 if let Some(proxy_config) = request_args.proxy_config.as_ref() {
81 client_builder = Self::configure_proxy(client_builder, proxy_config);
82 }
83
84 let client = client_builder.build().unwrap_or_else(|e| {
85 log_e!(TAG, "Failed to build reqwest client with proxy config: {}. Falling back to default client.", e);
86 reqwest::Client::new()
87 });
88
89 let mut request = client.request(method_actual, &request_args.url);
90
91 let timeout_duration = match request_args.timeout_ms > 0 {
92 true => Duration::from_millis(request_args.timeout_ms),
93 false => Duration::from_secs(10),
94 };
95 request = request.timeout(timeout_duration);
96
97 if let Some(headers) = &request_args.headers {
98 for (key, value) in headers {
99 request = request.header(key, value);
100 }
101 }
102
103 if let Some(params) = &request_args.query_params {
104 request = request.query(params);
105 }
106
107 if is_post {
108 let bytes = match &request_args.body {
109 Some(b) => b.clone(),
110 None => vec![],
111 };
112 let byte_len = bytes.len();
113
114 request = request.body(bytes);
115 request = request.header("Content-Length", byte_len.to_string());
116 }
117
118 request
119 }
120
121 fn configure_proxy(
122 client_builder: reqwest::ClientBuilder,
123 proxy_config: &ProxyConfig,
124 ) -> reqwest::ClientBuilder {
125 let (Some(host), Some(port)) = (&proxy_config.proxy_host, &proxy_config.proxy_port) else {
126 return client_builder;
127 };
128
129 let proxy_url = format!(
130 "{}://{}:{}",
131 proxy_config.proxy_protocol.as_deref().unwrap_or("http"),
132 host,
133 port
134 );
135
136 let Ok(proxy) = reqwest::Proxy::all(&proxy_url) else {
137 log_w!(TAG, "Failed to create proxy for URL: {}", proxy_url);
138 return client_builder;
139 };
140
141 let Some(auth) = &proxy_config.proxy_auth else {
142 return client_builder.proxy(proxy);
143 };
144
145 let Some((username, password)) = auth.split_once(':') else {
146 log_w!(
147 TAG,
148 "Invalid proxy auth format. Expected 'username:password'"
149 );
150 return client_builder.proxy(proxy);
151 };
152
153 client_builder.proxy(proxy.basic_auth(username, password))
154 }
155}
156
157fn get_error_message(error: reqwest::Error) -> String {
158 let mut error_message = error.to_string();
159
160 if let Some(url_error) = error.url() {
161 error_message.push_str(&format!(". URL: {}", url_error));
162 }
163
164 if let Some(status_error) = error.status() {
165 error_message.push_str(&format!(". Status: {}", status_error));
166 }
167
168 error_message
169}
170
171fn get_response_headers(response: &reqwest::Response) -> Option<HashMap<String, String>> {
172 let headers = response.headers();
173 if headers.is_empty() {
174 return None;
175 }
176
177 let mut headers_map = HashMap::new();
178 for (key, value) in headers {
179 headers_map.insert(key.to_string(), value.to_str().unwrap_or("").to_string());
180 }
181
182 Some(headers_map)
183}