1use chrono::Utc;
2use serde::Serialize;
3
4use super::providers::get_network_provider;
5use super::{HttpMethod, NetworkProvider, RequestArgs, Response};
6use crate::observability::ops_stats::{OpsStatsForInstance, OPS_STATS};
7use crate::observability::ErrorBoundaryEvent;
8use crate::sdk_diagnostics::marker::{ActionType, Marker, StepType};
9use crate::{log_d, log_i, log_w, StatsigErr};
10use std::collections::HashMap;
11use std::fmt;
12use std::sync::atomic::{AtomicBool, Ordering};
13use std::sync::{Arc, Weak};
14use std::time::Duration;
15
/// HTTP status codes considered transient: a failed request is retried only when its
/// status appears here. 522, 524 and 599 are non-standard codes — presumably
/// proxy/CDN timeout codes (TODO confirm).
const RETRY_CODES: [u16; 8] = [
    408, // Request Timeout
    500, // Internal Server Error
    502, // Bad Gateway
    503, // Service Unavailable
    504, // Gateway Timeout
    522, 524, 599,
];

/// Message logged when an attempt is skipped because the shutdown flag is set.
const SHUTDOWN_ERROR: &str = "Request was aborted because the client is shutting down";
18
/// Failure modes reported by `NetworkClient` requests.
///
/// NOTE: the derived `Serialize` encodes unit variants together with their variant
/// index in some serde formats, so keep the declaration order stable.
#[derive(PartialEq, Debug, Clone, Serialize)]
pub enum NetworkError {
    /// The applicable shutdown flag was set before an attempt was made.
    ShutdownError,
    /// In this file: the network provider could no longer be upgraded (it was dropped).
    RequestFailed,
    /// A retryable status persisted after `RequestArgs::retries` retries.
    RetriesExhausted,
    /// Serialization failure with a detail message (not produced in the code shown
    /// here — presumably raised by callers; confirm).
    SerializationError(String),
    /// The client was constructed with networking disabled.
    DisableNetworkOn,
    /// The response status is not in `RETRY_CODES`, so no retry was attempted.
    RequestNotRetryable,
}
28
29impl fmt::Display for NetworkError {
30 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
31 match self {
32 NetworkError::ShutdownError => write!(f, "ShutdownError"),
33 NetworkError::RequestFailed => write!(f, "RequestFailed"),
34 NetworkError::RetriesExhausted => write!(f, "RetriesExhausted"),
35 NetworkError::SerializationError(s) => write!(f, "SerializationError: {s}"),
36 NetworkError::DisableNetworkOn => write!(f, "DisableNetworkOn"),
37 NetworkError::RequestNotRetryable => write!(f, "RequestNotRetryable"),
38 }
39 }
40}
41
/// Log tag attached to every log line emitted in this file.
const TAG: &str = "NetworkClient";
43
/// HTTP client that sends requests through a shared network provider, with
/// retries, shutdown handling and diagnostics reporting.
pub struct NetworkClient {
    /// Client-wide headers merged into every request before it is sent.
    headers: HashMap<String, String>,
    /// Set by `shutdown()`; checked before each attempt unless the request
    /// supplies its own shutdown flag.
    is_shutdown: Arc<AtomicBool>,
    /// Sink for diagnostics markers and error reports for this SDK instance.
    ops_stats: Arc<OpsStatsForInstance>,
    /// Weak handle to the provider; if it has been dropped, requests fail with
    /// `NetworkError::RequestFailed`.
    net_provider: Weak<dyn NetworkProvider>,
    /// When true, every request fails immediately with `NetworkError::DisableNetworkOn`.
    disable_network: bool,
    /// When true, failures are only logged, not reported to `ops_stats`
    /// (set via `mute_network_error_log`).
    silent_on_network_failure: bool,
}
52
53impl NetworkClient {
54 #[must_use]
55 pub fn new(
56 sdk_key: &str,
57 headers: Option<HashMap<String, String>>,
58 disable_network: Option<bool>,
59 ) -> Self {
60 let net_provider = get_network_provider();
61
62 NetworkClient {
63 headers: headers.unwrap_or_default(),
64 is_shutdown: Arc::new(AtomicBool::new(false)),
65 net_provider,
66 ops_stats: OPS_STATS.get_for_instance(sdk_key),
67 disable_network: disable_network.unwrap_or_default(),
68 silent_on_network_failure: false,
69 }
70 }
71
72 pub fn shutdown(&self) {
73 self.is_shutdown.store(true, Ordering::SeqCst);
74 }
75
76 pub async fn get(&self, request_args: RequestArgs) -> Result<Response, NetworkError> {
77 self.make_request(HttpMethod::GET, request_args).await
78 }
79
80 pub async fn post(
81 &self,
82 mut request_args: RequestArgs,
83 body: Option<Vec<u8>>,
84 ) -> Result<Response, NetworkError> {
85 request_args.body = body;
86 self.make_request(HttpMethod::POST, request_args).await
87 }
88
89 async fn make_request(
90 &self,
91 method: HttpMethod,
92 mut request_args: RequestArgs,
93 ) -> Result<Response, NetworkError> {
94 let is_shutdown = if let Some(is_shutdown) = &request_args.is_shutdown {
95 is_shutdown.clone()
96 } else {
97 self.is_shutdown.clone()
98 };
99
100 if self.disable_network {
101 log_d!(TAG, "Network is disabled, not making requests");
102 return Err(NetworkError::DisableNetworkOn);
103 }
104
105 request_args.populate_headers(self.headers.clone());
106
107 let mut merged_headers = request_args.headers.unwrap_or_default();
108 if !self.headers.is_empty() {
109 merged_headers.extend(self.headers.clone());
110 }
111 merged_headers.insert(
112 "STATSIG-CLIENT-TIME".into(),
113 Utc::now().timestamp_millis().to_string(),
114 );
115 request_args.headers = Some(merged_headers);
116
117 let mut attempt = 0;
118
119 loop {
120 if let Some(key) = request_args.diagnostics_key {
121 self.ops_stats.add_marker(
122 Marker::new(key, ActionType::Start, Some(StepType::NetworkRequest))
123 .with_attempt(attempt)
124 .with_url(request_args.url.clone()),
125 None,
126 );
127 }
128 if is_shutdown.load(Ordering::SeqCst) {
129 log_i!(TAG, "{}", SHUTDOWN_ERROR);
130 return Err(NetworkError::ShutdownError);
131 }
132
133 let response = match self.net_provider.upgrade() {
134 Some(net_provider) => net_provider.send(&method, &request_args).await,
135 None => return Err(NetworkError::RequestFailed),
136 };
137
138 log_d!(
139 TAG,
140 "Response ({}): {}",
141 &request_args.url,
142 response.status_code
143 );
144
145 let status = response.status_code;
146 let sdk_region_str = response
147 .headers
148 .as_ref()
149 .and_then(|h| h.get("x-statsig-region"));
150 let success = (200..300).contains(&status);
151
152 let error_message = response
153 .error
154 .clone()
155 .unwrap_or_else(|| get_error_message_for_status(status));
156
157 if let Some(key) = request_args.diagnostics_key {
158 let mut end_marker =
159 Marker::new(key, ActionType::End, Some(StepType::NetworkRequest))
160 .with_attempt(attempt)
161 .with_url(request_args.url.clone())
162 .with_status_code(status)
163 .with_is_success(success)
164 .with_sdk_region(sdk_region_str.map(|s| s.to_owned()));
165
166 let error_map = if !error_message.is_empty() {
167 let mut map = HashMap::new();
168 map.insert("name".to_string(), "NetworkError".to_string());
169 map.insert("message".to_string(), error_message.clone());
170 map.insert("code".to_string(), status.to_string());
171 Some(map)
172 } else {
173 None
174 };
175
176 if let Some(error_map) = error_map {
177 end_marker = end_marker.with_error(error_map);
178 }
179
180 self.ops_stats.add_marker(end_marker, None);
181 }
182
183 if success {
184 return Ok(response);
185 }
186
187 if !RETRY_CODES.contains(&status) {
188 let msg = format!("Network error, not retrying: {} {}", status, error_message);
189 self.log_warning(
190 StatsigErr::NetworkError(NetworkError::RequestNotRetryable, Some(msg)),
191 &request_args,
192 );
193 return Err(NetworkError::RequestNotRetryable);
194 }
195
196 if attempt >= request_args.retries {
197 let msg = format!(
198 "Network error, retries exhausted: {} {}",
199 status, error_message
200 );
201 self.log_warning(
202 StatsigErr::NetworkError(NetworkError::RetriesExhausted, Some(msg)),
203 &request_args,
204 );
205 return Err(NetworkError::RetriesExhausted);
206 }
207
208 attempt += 1;
209 let backoff_ms = 2_u64.pow(attempt) * 100;
210
211 log_w!(
212 TAG, "Network request failed with status code {} (attempt {}), will retry after {}ms...\n{}",
213 status,
214 attempt,
215 backoff_ms,
216 error_message
217 );
218
219 tokio::time::sleep(Duration::from_millis(backoff_ms)).await;
220 }
221 }
222
223 pub fn mute_network_error_log(mut self) -> Self {
224 self.silent_on_network_failure = true;
225 self
226 }
227
228 fn log_warning(&self, error: StatsigErr, args: &RequestArgs) {
229 log_w!(TAG, "{}", error);
230 if !self.silent_on_network_failure {
231 let dedupe_key = format!("{:?}", args.diagnostics_key);
232 self.ops_stats.log_error(ErrorBoundaryEvent {
233 tag: TAG.to_string(),
234 bypass_dedupe: false,
235 info: error,
236 dedupe_key: Some(dedupe_key),
237 extra: None,
238 });
239 }
240 }
241}
242
/// Maps an HTTP status code to a short human-readable reason.
///
/// Any 2xx status yields an empty string. The codes listed below (and `0`, reported as
/// "Unknown Error") yield a fixed phrase. Every other code yields `"HTTP Error <code>"`.
fn get_error_message_for_status(status: u16) -> String {
    let phrase = match status {
        200..=299 => return String::new(),
        0 => "Unknown Error",
        400 => "Bad Request",
        401 => "Unauthorized",
        403 => "Forbidden",
        404 => "Not Found",
        405 => "Method Not Allowed",
        406 => "Not Acceptable",
        408 => "Request Timeout",
        500 => "Internal Server Error",
        502 => "Bad Gateway",
        503 => "Service Unavailable",
        504 => "Gateway Timeout",
        other => return format!("HTTP Error {other}"),
    };
    phrase.to_owned()
}