statsig_rust/networking/providers/
net_provider_reqwest.rs1use std::collections::HashMap;
2use std::io::{BufReader, Seek, SeekFrom, Write};
3use std::time::Duration;
4
5use async_trait::async_trait;
6
7use crate::log_d;
8use crate::{
9 log_e, log_w,
10 networking::{
11 http_types::{HttpMethod, RequestArgs, Response, ResponseData},
12 NetworkProvider,
13 },
14 StatsigErr,
15};
16
17use crate::networking::proxy_config::ProxyConfig;
18use reqwest::Method;
19
/// Log tag passed to the `log_d!`, `log_w!`, and `log_e!` calls in this file.
const TAG: &str = "NetworkProviderReqwest";
21
/// `NetworkProvider` implementation backed by `reqwest`.
///
/// `has_file_write_access` records whether a temporary file could be created when the
/// provider was constructed. When it is `false`, response bodies are always buffered in
/// memory rather than streamed into a spooled temporary file.
#[derive(Debug)]
pub struct NetworkProviderReqwest {
    has_file_write_access: bool,
}
25
26impl NetworkProviderReqwest {
27 pub fn new() -> Self {
28 Self {
29 has_file_write_access: tempfile::tempfile().is_ok(),
30 }
31 }
32}
33
34impl Default for NetworkProviderReqwest {
35 fn default() -> Self {
36 Self::new()
37 }
38}
39
40#[async_trait]
41impl NetworkProvider for NetworkProviderReqwest {
42 async fn send(&self, method: &HttpMethod, args: &RequestArgs) -> Response {
43 if let Some(is_shutdown) = &args.is_shutdown {
44 if is_shutdown.load(std::sync::atomic::Ordering::SeqCst) {
45 return Response {
46 status_code: None,
47 data: None,
48 error: Some("Request was shutdown".to_string()),
49 };
50 }
51 }
52
53 let request = self.build_request(method, args);
54
55 let mut error = None;
56 let mut status_code = None;
57 let mut data = None;
58
59 match request.send().await {
60 Ok(response) => {
61 status_code = Some(response.status().as_u16());
62
63 let data_result =
64 if !self.has_file_write_access || args.disable_file_streaming == Some(true) {
65 Self::write_response_to_in_memory_buffer(response).await
66 } else {
67 Self::write_response_to_temp_file(response).await
68 };
69
70 match data_result {
71 Ok(response_data) => data = Some(response_data),
72 Err(e) => {
73 error = Some(e.to_string());
74 }
75 }
76 }
77 Err(e) => {
78 let error_message = get_error_message(e);
79 error = Some(error_message);
80 }
81 }
82
83 Response {
84 status_code,
85 data,
86 error,
87 }
88 }
89}
90
91impl NetworkProviderReqwest {
92 fn build_request(
93 &self,
94 method: &HttpMethod,
95 request_args: &RequestArgs,
96 ) -> reqwest::RequestBuilder {
97 let method_actual = match method {
98 HttpMethod::GET => Method::GET,
99 HttpMethod::POST => Method::POST,
100 };
101 let is_post = method_actual == Method::POST;
102
103 let mut client_builder = reqwest::Client::builder();
104
105 if let Some(proxy_config) = request_args.proxy_config.as_ref() {
107 client_builder = Self::configure_proxy(client_builder, proxy_config);
108 }
109
110 let client = client_builder.build().unwrap_or_else(|e| {
111 log_e!(TAG, "Failed to build reqwest client with proxy config: {}. Falling back to default client.", e);
112 reqwest::Client::new()
113 });
114
115 let mut request = client.request(method_actual, &request_args.url);
116
117 let timeout_duration = match request_args.timeout_ms > 0 {
118 true => Duration::from_millis(request_args.timeout_ms),
119 false => Duration::from_secs(10),
120 };
121 request = request.timeout(timeout_duration);
122
123 if let Some(headers) = &request_args.headers {
124 for (key, value) in headers {
125 request = request.header(key, value);
126 }
127 }
128
129 if let Some(params) = &request_args.query_params {
130 request = request.query(params);
131 }
132
133 if is_post {
134 let bytes = match &request_args.body {
135 Some(b) => b.clone(),
136 None => vec![],
137 };
138 let byte_len = bytes.len();
139
140 request = request.body(bytes);
141 request = request.header("Content-Length", byte_len.to_string());
142 }
143
144 request
145 }
146
147 fn configure_proxy(
148 client_builder: reqwest::ClientBuilder,
149 proxy_config: &ProxyConfig,
150 ) -> reqwest::ClientBuilder {
151 let (Some(host), Some(port)) = (&proxy_config.proxy_host, &proxy_config.proxy_port) else {
152 return client_builder;
153 };
154
155 let proxy_url = format!(
156 "{}://{}:{}",
157 proxy_config.proxy_protocol.as_deref().unwrap_or("http"),
158 host,
159 port
160 );
161
162 let Ok(proxy) = reqwest::Proxy::all(&proxy_url) else {
163 log_w!(TAG, "Failed to create proxy for URL: {}", proxy_url);
164 return client_builder;
165 };
166
167 let Some(auth) = &proxy_config.proxy_auth else {
168 return client_builder.proxy(proxy);
169 };
170
171 let Some((username, password)) = auth.split_once(':') else {
172 log_w!(
173 TAG,
174 "Invalid proxy auth format. Expected 'username:password'"
175 );
176 return client_builder.proxy(proxy);
177 };
178
179 client_builder.proxy(proxy.basic_auth(username, password))
180 }
181
182 async fn write_response_to_temp_file(
183 response: reqwest::Response,
184 ) -> Result<ResponseData, StatsigErr> {
185 let headers = get_response_headers(&response);
186 let mut response = response;
187 let mut temp_file = tempfile::spooled_tempfile(1024 * 1024 * 2); let mut total_bytes = 0;
190 while let Some(item) = response
191 .chunk()
192 .await
193 .map_err(|e| StatsigErr::FileError(e.to_string()))?
194 {
195 total_bytes += item.len();
196 temp_file
197 .write_all(&item)
198 .map_err(|e| StatsigErr::FileError(e.to_string()))?;
199 }
200
201 temp_file
202 .seek(SeekFrom::Start(0))
203 .map_err(|e| StatsigErr::FileError(e.to_string()))?;
204
205 let reader = BufReader::new(temp_file);
206
207 log_d!(TAG, "Wrote {} bytes to spooled temp file", total_bytes);
208
209 Ok(ResponseData::from_stream_with_headers(
210 Box::new(reader),
211 headers,
212 ))
213 }
214
215 async fn write_response_to_in_memory_buffer(
216 response: reqwest::Response,
217 ) -> Result<ResponseData, StatsigErr> {
218 let headers = get_response_headers(&response);
219 let bytes = response
220 .bytes()
221 .await
222 .map_err(|e| StatsigErr::SerializationError(e.to_string()))?;
223
224 log_d!(TAG, "Wrote {} bytes to in-memory buffer", bytes.len());
225
226 Ok(ResponseData::from_bytes_with_headers(
227 bytes.to_vec(),
228 headers,
229 ))
230 }
231}
232
233fn get_error_message(error: reqwest::Error) -> String {
234 let mut error_message = error.to_string();
235
236 if let Some(url_error) = error.url() {
237 error_message.push_str(&format!(". URL: {}", url_error));
238 }
239
240 if let Some(status_error) = error.status() {
241 error_message.push_str(&format!(". Status: {}", status_error));
242 }
243
244 error_message
245}
246
247fn get_response_headers(response: &reqwest::Response) -> Option<HashMap<String, String>> {
248 let headers = response.headers();
249 if headers.is_empty() {
250 return None;
251 }
252
253 let mut headers_map = HashMap::new();
254 for (key, value) in headers {
255 headers_map.insert(key.to_string(), value.to_str().unwrap_or("").to_string());
256 }
257
258 Some(headers_map)
259}