1use chrono::Utc;
2
3use super::network_error::NetworkError;
4use super::providers::get_network_provider;
5use super::{HttpMethod, NetworkProvider, RequestArgs, Response};
6use crate::networking::proxy_config::ProxyConfig;
7use crate::observability::ops_stats::{OpsStatsForInstance, OPS_STATS};
8use crate::observability::ErrorBoundaryEvent;
9use crate::sdk_diagnostics::marker::{ActionType, Marker, StepType};
10use crate::{log_d, log_i, log_w, StatsigOptions};
11use std::collections::HashMap;
12use std::sync::atomic::{AtomicBool, Ordering};
13use std::sync::{Arc, Weak};
14use std::time::Duration;
15
/// HTTP status codes for which a failed request is returned to the caller
/// immediately instead of being retried.
const NON_RETRY_CODES: [u16; 6] = [
    400, // Bad Request
    403, // Forbidden
    413, // Payload Too Large
    405, // Method Not Allowed
    429, // Too Many Requests
    501, // Not Implemented
];

/// Message logged when a request is abandoned because the shutdown flag is set.
const SHUTDOWN_ERROR: &str = "Request was aborted because the client is shutting down";

/// Tag attached to log lines and error-boundary events emitted from this file.
const TAG: &str = "NetworkClient";
27
28pub struct NetworkClient {
29 headers: HashMap<String, String>,
30 is_shutdown: Arc<AtomicBool>,
31 ops_stats: Arc<OpsStatsForInstance>,
32 net_provider: Weak<dyn NetworkProvider>,
33 disable_network: bool,
34 proxy_config: Option<ProxyConfig>,
35 silent_on_network_failure: bool,
36 disable_file_streaming: bool,
37}
38
39impl NetworkClient {
40 #[must_use]
41 pub fn new(
42 sdk_key: &str,
43 headers: Option<HashMap<String, String>>,
44 options: Option<&StatsigOptions>,
45 ) -> Self {
46 let net_provider = get_network_provider();
47 let (disable_network, proxy_config) = options
48 .map(|opts| {
49 (
50 opts.disable_network.unwrap_or(false),
51 opts.proxy_config.clone(),
52 )
53 })
54 .unwrap_or((false, None));
55
56 NetworkClient {
57 headers: headers.unwrap_or_default(),
58 is_shutdown: Arc::new(AtomicBool::new(false)),
59 net_provider,
60 ops_stats: OPS_STATS.get_for_instance(sdk_key),
61 disable_network,
62 proxy_config,
63 silent_on_network_failure: false,
64 disable_file_streaming: options
65 .map(|opts| opts.disable_disk_access.unwrap_or(false))
66 .unwrap_or(false),
67 }
68 }
69
70 pub fn shutdown(&self) {
71 self.is_shutdown.store(true, Ordering::SeqCst);
72 }
73
74 pub async fn get(&self, request_args: RequestArgs) -> Result<Response, NetworkError> {
75 self.make_request(HttpMethod::GET, request_args).await
76 }
77
78 pub async fn post(
79 &self,
80 mut request_args: RequestArgs,
81 body: Option<Vec<u8>>,
82 ) -> Result<Response, NetworkError> {
83 request_args.body = body;
84 self.make_request(HttpMethod::POST, request_args).await
85 }
86
87 async fn make_request(
88 &self,
89 method: HttpMethod,
90 mut request_args: RequestArgs,
91 ) -> Result<Response, NetworkError> {
92 let is_shutdown = if let Some(is_shutdown) = &request_args.is_shutdown {
93 is_shutdown.clone()
94 } else {
95 self.is_shutdown.clone()
96 };
97
98 if self.disable_network {
99 log_d!(TAG, "Network is disabled, not making requests");
100 return Err(NetworkError::DisableNetworkOn(request_args.url));
101 }
102
103 request_args.populate_headers(self.headers.clone());
104
105 if request_args.disable_file_streaming.is_none() {
106 request_args.disable_file_streaming = Some(self.disable_file_streaming);
107 }
108
109 let mut merged_headers = request_args.headers.unwrap_or_default();
110 if !self.headers.is_empty() {
111 merged_headers.extend(self.headers.clone());
112 }
113 merged_headers.insert(
114 "STATSIG-CLIENT-TIME".into(),
115 Utc::now().timestamp_millis().to_string(),
116 );
117 request_args.headers = Some(merged_headers);
118
119 if let Some(proxy_config) = &self.proxy_config {
121 request_args.proxy_config = Some(proxy_config.clone());
122 }
123 let supports_proto = request_args
124 .headers
125 .as_ref()
126 .and_then(|headers| headers.get("statsig-supports-proto"))
127 .map(|value| value.eq_ignore_ascii_case("true"));
128 let mut attempt = 0;
129
130 loop {
131 if let Some(key) = request_args.diagnostics_key {
132 self.ops_stats.add_marker(
133 Marker::new(key, ActionType::Start, Some(StepType::NetworkRequest))
134 .with_attempt(attempt)
135 .with_url(request_args.url.clone())
136 .with_request_supports_proto(supports_proto),
137 None,
138 );
139 }
140 if is_shutdown.load(Ordering::SeqCst) {
141 log_i!(TAG, "{}", SHUTDOWN_ERROR);
142 return Err(NetworkError::ShutdownError(request_args.url));
143 }
144
145 let mut response = match self.net_provider.upgrade() {
146 Some(net_provider) => net_provider.send(&method, &request_args).await,
147 None => {
148 return Err(NetworkError::RequestFailed(
149 request_args.url,
150 None,
151 "Failed to get a NetworkProvider instance".to_string(),
152 ));
153 }
154 };
155
156 let status = response.status_code;
157 let error_message = response
158 .error
159 .clone()
160 .unwrap_or_else(|| get_error_message_for_status(status, response.data.as_mut()));
161
162 let content_type = response
163 .data
164 .as_ref()
165 .and_then(|data| data.get_header_ref("content-type"));
166
167 log_d!(
168 TAG,
169 "Response url({}) status({:?}) content-type({:?})",
170 &request_args.url,
171 response.status_code,
172 content_type
173 );
174
175 let sdk_region_str = response
176 .data
177 .as_ref()
178 .and_then(|data| data.get_header_ref("x-statsig-region").cloned());
179 let success = (200..300).contains(&status.unwrap_or(0));
180
181 if let Some(key) = request_args.diagnostics_key {
182 let mut end_marker =
183 Marker::new(key, ActionType::End, Some(StepType::NetworkRequest))
184 .with_attempt(attempt)
185 .with_url(request_args.url.clone())
186 .with_is_success(success)
187 .with_content_type(content_type.cloned())
188 .with_request_supports_proto(supports_proto)
189 .with_sdk_region(sdk_region_str.map(|s| s.to_owned()));
190
191 if let Some(status_code) = status {
192 end_marker = end_marker.with_status_code(status_code);
193 }
194
195 let error_map = if !error_message.is_empty() {
196 let mut map = HashMap::new();
197 map.insert("name".to_string(), "NetworkError".to_string());
198 map.insert("message".to_string(), error_message.clone());
199 let status_string = match status {
200 Some(code) => code.to_string(),
201 None => "None".to_string(),
202 };
203 map.insert("code".to_string(), status_string);
204 Some(map)
205 } else {
206 None
207 };
208
209 if let Some(error_map) = error_map {
210 end_marker = end_marker.with_error(error_map);
211 }
212
213 self.ops_stats.add_marker(end_marker, None);
214 }
215
216 if success {
217 return Ok(response);
218 }
219
220 if NON_RETRY_CODES.contains(&status.unwrap_or(0)) {
221 let error = NetworkError::RequestNotRetryable(
222 request_args.url.clone(),
223 status,
224 error_message,
225 );
226 self.log_warning(&error, &request_args);
227 return Err(error);
228 }
229
230 if attempt >= request_args.retries {
231 let error = NetworkError::RetriesExhausted(
232 request_args.url.clone(),
233 status,
234 attempt + 1,
235 error_message,
236 );
237 self.log_warning(&error, &request_args);
238 return Err(error);
239 }
240
241 attempt += 1;
242 let backoff_ms = 2_u64.pow(attempt) * 100;
243
244 log_i!(
245 TAG, "Network request failed with status code {} (attempt {}/{}), will retry after {}ms...\n{}",
246 status.map_or("unknown".to_string(), |s| s.to_string()),
247 attempt,
248 request_args.retries + 1,
249 backoff_ms,
250 error_message
251 );
252
253 tokio::time::sleep(Duration::from_millis(backoff_ms)).await;
254 }
255 }
256
257 pub fn mute_network_error_log(mut self) -> Self {
258 self.silent_on_network_failure = true;
259 self
260 }
261
262 fn log_warning(&self, error: &NetworkError, args: &RequestArgs) {
263 let exception = error.name();
264
265 log_w!(TAG, "{}", error);
266 if !self.silent_on_network_failure {
267 let dedupe_key = format!("{:?}", args.diagnostics_key);
268 self.ops_stats.log_error(ErrorBoundaryEvent {
269 tag: TAG.to_string(),
270 exception: exception.to_string(),
271 bypass_dedupe: false,
272 info: serde_json::to_string(error).unwrap_or_default(),
273 dedupe_key: Some(dedupe_key),
274 extra: None,
275 });
276 }
277 }
278}
279
280fn get_error_message_for_status(
281 status: Option<u16>,
282 data: Option<&mut super::ResponseData>,
283) -> String {
284 if (200..300).contains(&status.unwrap_or(0)) {
285 return String::new();
286 }
287
288 let mut message = String::new();
289 if let Some(data) = data {
290 let lossy_str = data.read_to_string().unwrap_or_default();
291 if lossy_str.is_ascii() {
292 message = lossy_str.to_string();
293 }
294 }
295
296 let status_value = match status {
297 Some(code) => code,
298 None => return format!("HTTP Error None: {message}"),
299 };
300
301 let generic_message = match status_value {
302 400 => "Bad Request",
303 401 => "Unauthorized",
304 403 => "Forbidden",
305 404 => "Not Found",
306 405 => "Method Not Allowed",
307 406 => "Not Acceptable",
308 408 => "Request Timeout",
309 500 => "Internal Server Error",
310 502 => "Bad Gateway",
311 503 => "Service Unavailable",
312 504 => "Gateway Timeout",
313 0 => "Unknown Error",
314 _ => return format!("HTTP Error {status_value}: {message}"),
315 };
316
317 if message.is_empty() {
318 return generic_message.to_string();
319 }
320
321 format!("{generic_message}: {message}")
322}