statsig_rust/networking/providers/
net_provider_reqwest.rs1use std::collections::HashMap;
2use std::io::{BufReader, Seek, SeekFrom, Write};
3use std::time::Duration;
4
5use async_trait::async_trait;
6
7use crate::{
8 log_e, log_w,
9 networking::{
10 http_types::{HttpMethod, RequestArgs, Response, ResponseData},
11 NetworkProvider,
12 },
13 StatsigErr,
14};
15
16use crate::networking::proxy_config::ProxyConfig;
17use reqwest::Method;
18
/// Tag passed to the `log_*!` macros for every message emitted by this provider.
const TAG: &str = "NetworkProviderReqwest";

/// [`NetworkProvider`] implementation backed by the `reqwest` HTTP client.
///
/// The struct carries no state; request construction (including the `reqwest::Client`)
/// happens per call.
pub struct NetworkProviderReqwest {}
22
23#[async_trait]
24impl NetworkProvider for NetworkProviderReqwest {
25 async fn send(&self, method: &HttpMethod, args: &RequestArgs) -> Response {
26 if let Some(is_shutdown) = &args.is_shutdown {
27 if is_shutdown.load(std::sync::atomic::Ordering::SeqCst) {
28 return Response {
29 status_code: None,
30 data: None,
31 error: Some("Request was shutdown".to_string()),
32 headers: None,
33 };
34 }
35 }
36
37 let request = self.build_request(method, args);
38
39 let mut error = None;
40 let mut status_code = None;
41 let mut data = None;
42 let mut headers = None;
43
44 match request.send().await {
45 Ok(response) => {
46 status_code = Some(response.status().as_u16());
47 headers = get_response_headers(&response);
48
49 match Self::write_response_to_temp_file(response).await {
50 Ok(response_data) => data = Some(response_data),
51 Err(e) => {
52 error = Some(e.to_string());
53 }
54 }
55 }
56 Err(e) => {
57 let error_message = get_error_message(e);
58 error = Some(error_message);
59 }
60 }
61
62 Response {
63 status_code,
64 data,
65 error,
66 headers,
67 }
68 }
69}
70
71impl NetworkProviderReqwest {
72 fn build_request(
73 &self,
74 method: &HttpMethod,
75 request_args: &RequestArgs,
76 ) -> reqwest::RequestBuilder {
77 let method_actual = match method {
78 HttpMethod::GET => Method::GET,
79 HttpMethod::POST => Method::POST,
80 };
81 let is_post = method_actual == Method::POST;
82
83 let mut client_builder = reqwest::Client::builder();
84
85 if let Some(proxy_config) = request_args.proxy_config.as_ref() {
87 client_builder = Self::configure_proxy(client_builder, proxy_config);
88 }
89
90 let client = client_builder.build().unwrap_or_else(|e| {
91 log_e!(TAG, "Failed to build reqwest client with proxy config: {}. Falling back to default client.", e);
92 reqwest::Client::new()
93 });
94
95 let mut request = client.request(method_actual, &request_args.url);
96
97 let timeout_duration = match request_args.timeout_ms > 0 {
98 true => Duration::from_millis(request_args.timeout_ms),
99 false => Duration::from_secs(10),
100 };
101 request = request.timeout(timeout_duration);
102
103 if let Some(headers) = &request_args.headers {
104 for (key, value) in headers {
105 request = request.header(key, value);
106 }
107 }
108
109 if let Some(params) = &request_args.query_params {
110 request = request.query(params);
111 }
112
113 if is_post {
114 let bytes = match &request_args.body {
115 Some(b) => b.clone(),
116 None => vec![],
117 };
118 let byte_len = bytes.len();
119
120 request = request.body(bytes);
121 request = request.header("Content-Length", byte_len.to_string());
122 }
123
124 request
125 }
126
127 fn configure_proxy(
128 client_builder: reqwest::ClientBuilder,
129 proxy_config: &ProxyConfig,
130 ) -> reqwest::ClientBuilder {
131 let (Some(host), Some(port)) = (&proxy_config.proxy_host, &proxy_config.proxy_port) else {
132 return client_builder;
133 };
134
135 let proxy_url = format!(
136 "{}://{}:{}",
137 proxy_config.proxy_protocol.as_deref().unwrap_or("http"),
138 host,
139 port
140 );
141
142 let Ok(proxy) = reqwest::Proxy::all(&proxy_url) else {
143 log_w!(TAG, "Failed to create proxy for URL: {}", proxy_url);
144 return client_builder;
145 };
146
147 let Some(auth) = &proxy_config.proxy_auth else {
148 return client_builder.proxy(proxy);
149 };
150
151 let Some((username, password)) = auth.split_once(':') else {
152 log_w!(
153 TAG,
154 "Invalid proxy auth format. Expected 'username:password'"
155 );
156 return client_builder.proxy(proxy);
157 };
158
159 client_builder.proxy(proxy.basic_auth(username, password))
160 }
161
162 async fn write_response_to_temp_file(
163 response: reqwest::Response,
164 ) -> Result<ResponseData, StatsigErr> {
165 let mut response = response;
166 let mut temp_file = tempfile::spooled_tempfile(1024 * 1024 * 2); while let Some(item) = response
169 .chunk()
170 .await
171 .map_err(|e| StatsigErr::FileError(e.to_string()))?
172 {
173 temp_file
174 .write_all(&item)
175 .map_err(|e| StatsigErr::FileError(e.to_string()))?;
176 }
177
178 temp_file
179 .seek(SeekFrom::Start(0))
180 .map_err(|e| StatsigErr::FileError(e.to_string()))?;
181
182 let reader = BufReader::new(temp_file);
183 Ok(ResponseData::from_stream(Box::new(reader)))
184 }
185}
186
187fn get_error_message(error: reqwest::Error) -> String {
188 let mut error_message = error.to_string();
189
190 if let Some(url_error) = error.url() {
191 error_message.push_str(&format!(". URL: {}", url_error));
192 }
193
194 if let Some(status_error) = error.status() {
195 error_message.push_str(&format!(". Status: {}", status_error));
196 }
197
198 error_message
199}
200
201fn get_response_headers(response: &reqwest::Response) -> Option<HashMap<String, String>> {
202 let headers = response.headers();
203 if headers.is_empty() {
204 return None;
205 }
206
207 let mut headers_map = HashMap::new();
208 for (key, value) in headers {
209 headers_map.insert(key.to_string(), value.to_str().unwrap_or("").to_string());
210 }
211
212 Some(headers_map)
213}