statsig_rust/networking/providers/
net_provider_reqwest.rs1use std::collections::HashMap;
2use std::io::{BufReader, Seek, SeekFrom, Write};
3use std::time::Duration;
4
5use async_trait::async_trait;
6
7use crate::{
8 log_e, log_w,
9 networking::{
10 http_types::{HttpMethod, RequestArgs, Response, ResponseData},
11 NetworkProvider,
12 },
13 StatsigErr,
14};
15
16use crate::networking::proxy_config::ProxyConfig;
17use reqwest::Method;
18
19const TAG: &str = "NetworkProviderReqwest";
20
21pub struct NetworkProviderReqwest {
22 has_file_write_access: bool,
23}
24
25impl NetworkProviderReqwest {
26 pub fn new() -> Self {
27 let has_file_write_access = match tempfile::tempfile() {
28 Ok(_) => true,
29 Err(_) => false,
30 };
31
32 Self {
33 has_file_write_access,
34 }
35 }
36}
37
38#[async_trait]
39impl NetworkProvider for NetworkProviderReqwest {
40 async fn send(&self, method: &HttpMethod, args: &RequestArgs) -> Response {
41 if let Some(is_shutdown) = &args.is_shutdown {
42 if is_shutdown.load(std::sync::atomic::Ordering::SeqCst) {
43 return Response {
44 status_code: None,
45 data: None,
46 error: Some("Request was shutdown".to_string()),
47 headers: None,
48 };
49 }
50 }
51
52 let request = self.build_request(method, args);
53
54 let mut error = None;
55 let mut status_code = None;
56 let mut data = None;
57 let mut headers = None;
58
59 match request.send().await {
60 Ok(response) => {
61 status_code = Some(response.status().as_u16());
62 headers = get_response_headers(&response);
63
64 let data_result =
65 if !self.has_file_write_access || args.disable_file_streaming == Some(true) {
66 Self::write_response_to_in_memory_buffer(response).await
67 } else {
68 Self::write_response_to_temp_file(response).await
69 };
70
71 match data_result {
72 Ok(response_data) => data = Some(response_data),
73 Err(e) => {
74 error = Some(e.to_string());
75 }
76 }
77 }
78 Err(e) => {
79 let error_message = get_error_message(e);
80 error = Some(error_message);
81 }
82 }
83
84 Response {
85 status_code,
86 data,
87 error,
88 headers,
89 }
90 }
91}
92
93impl NetworkProviderReqwest {
94 fn build_request(
95 &self,
96 method: &HttpMethod,
97 request_args: &RequestArgs,
98 ) -> reqwest::RequestBuilder {
99 let method_actual = match method {
100 HttpMethod::GET => Method::GET,
101 HttpMethod::POST => Method::POST,
102 };
103 let is_post = method_actual == Method::POST;
104
105 let mut client_builder = reqwest::Client::builder();
106
107 if let Some(proxy_config) = request_args.proxy_config.as_ref() {
109 client_builder = Self::configure_proxy(client_builder, proxy_config);
110 }
111
112 let client = client_builder.build().unwrap_or_else(|e| {
113 log_e!(TAG, "Failed to build reqwest client with proxy config: {}. Falling back to default client.", e);
114 reqwest::Client::new()
115 });
116
117 let mut request = client.request(method_actual, &request_args.url);
118
119 let timeout_duration = match request_args.timeout_ms > 0 {
120 true => Duration::from_millis(request_args.timeout_ms),
121 false => Duration::from_secs(10),
122 };
123 request = request.timeout(timeout_duration);
124
125 if let Some(headers) = &request_args.headers {
126 for (key, value) in headers {
127 request = request.header(key, value);
128 }
129 }
130
131 if let Some(params) = &request_args.query_params {
132 request = request.query(params);
133 }
134
135 if is_post {
136 let bytes = match &request_args.body {
137 Some(b) => b.clone(),
138 None => vec![],
139 };
140 let byte_len = bytes.len();
141
142 request = request.body(bytes);
143 request = request.header("Content-Length", byte_len.to_string());
144 }
145
146 request
147 }
148
149 fn configure_proxy(
150 client_builder: reqwest::ClientBuilder,
151 proxy_config: &ProxyConfig,
152 ) -> reqwest::ClientBuilder {
153 let (Some(host), Some(port)) = (&proxy_config.proxy_host, &proxy_config.proxy_port) else {
154 return client_builder;
155 };
156
157 let proxy_url = format!(
158 "{}://{}:{}",
159 proxy_config.proxy_protocol.as_deref().unwrap_or("http"),
160 host,
161 port
162 );
163
164 let Ok(proxy) = reqwest::Proxy::all(&proxy_url) else {
165 log_w!(TAG, "Failed to create proxy for URL: {}", proxy_url);
166 return client_builder;
167 };
168
169 let Some(auth) = &proxy_config.proxy_auth else {
170 return client_builder.proxy(proxy);
171 };
172
173 let Some((username, password)) = auth.split_once(':') else {
174 log_w!(
175 TAG,
176 "Invalid proxy auth format. Expected 'username:password'"
177 );
178 return client_builder.proxy(proxy);
179 };
180
181 client_builder.proxy(proxy.basic_auth(username, password))
182 }
183
184 async fn write_response_to_temp_file(
185 response: reqwest::Response,
186 ) -> Result<ResponseData, StatsigErr> {
187 let mut response = response;
188 let mut temp_file = tempfile::spooled_tempfile(1024 * 1024 * 2); while let Some(item) = response
191 .chunk()
192 .await
193 .map_err(|e| StatsigErr::FileError(e.to_string()))?
194 {
195 temp_file
196 .write_all(&item)
197 .map_err(|e| StatsigErr::FileError(e.to_string()))?;
198 }
199
200 temp_file
201 .seek(SeekFrom::Start(0))
202 .map_err(|e| StatsigErr::FileError(e.to_string()))?;
203
204 let reader = BufReader::new(temp_file);
205 Ok(ResponseData::from_stream(Box::new(reader)))
206 }
207
208 async fn write_response_to_in_memory_buffer(
209 response: reqwest::Response,
210 ) -> Result<ResponseData, StatsigErr> {
211 let bytes = response
212 .bytes()
213 .await
214 .map_err(|e| StatsigErr::SerializationError(e.to_string()))?;
215
216 Ok(ResponseData::from_bytes(bytes.to_vec()))
217 }
218}
219
220fn get_error_message(error: reqwest::Error) -> String {
221 let mut error_message = error.to_string();
222
223 if let Some(url_error) = error.url() {
224 error_message.push_str(&format!(". URL: {}", url_error));
225 }
226
227 if let Some(status_error) = error.status() {
228 error_message.push_str(&format!(". Status: {}", status_error));
229 }
230
231 error_message
232}
233
234fn get_response_headers(response: &reqwest::Response) -> Option<HashMap<String, String>> {
235 let headers = response.headers();
236 if headers.is_empty() {
237 return None;
238 }
239
240 let mut headers_map = HashMap::new();
241 for (key, value) in headers {
242 headers_map.insert(key.to_string(), value.to_str().unwrap_or("").to_string());
243 }
244
245 Some(headers_map)
246}