1use chrono::Utc;
2
3use super::providers::get_network_provider;
4use super::{HttpMethod, NetworkProvider, RequestArgs};
5use crate::observability::ops_stats::{OpsStatsForInstance, OPS_STATS};
6use crate::observability::ErrorBoundaryEvent;
7use crate::sdk_diagnostics::marker::{ActionType, Marker, StepType};
8use crate::{log_d, log_error_to_statsig_and_console, log_i, log_w};
9use std::collections::HashMap;
10use std::sync::atomic::{AtomicBool, Ordering};
11use std::sync::{Arc, Weak};
12use std::time::Duration;
13
/// HTTP status codes treated as transient. A response carrying one of these is retried (while
/// the request's retry budget lasts); any other non-2xx status fails immediately.
const RETRY_CODES: [u16; 8] = [
    408, // Request Timeout
    500, // Internal Server Error
    502, // Bad Gateway
    503, // Service Unavailable
    504, // Gateway Timeout
    522, // non-standard; presumably a proxy/CDN connection timeout — confirm
    524, // non-standard; presumably a proxy/CDN origin timeout — confirm
    599, // non-standard; presumably a network connect timeout — confirm
];
/// Message logged when a request is abandoned because the shutdown flag is set.
const SHUTDOWN_ERROR: &str = "Request was aborted because the client is shutting down";
16
/// Failure modes reported by `NetworkClient` requests.
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
pub enum NetworkError {
    /// The shutdown flag was set when an attempt was about to start; nothing was sent.
    ShutdownError,
    /// The request failed without being retried: a non-retryable non-2xx status, a 2xx
    /// response with no body, or the network provider was no longer available.
    RequestFailed,
    /// Every attempt returned a retryable status and the retry budget ran out.
    RetriesExhausted,
    /// A 2xx response body could not be decoded as UTF-8.
    SerializationError,
}
/// Log tag used for every message emitted by the network client.
const TAG: &str = "NetworkClient";
25
/// Asynchronous HTTP client that sends requests through a `NetworkProvider`, retrying
/// transient failures and honoring a shutdown flag.
pub struct NetworkClient {
    /// Headers merged into every request sent by this client.
    headers: HashMap<String, String>,
    /// Client-wide shutdown flag; a request may supply its own flag, which takes precedence.
    is_shutdown: Arc<AtomicBool>,
    /// Per-SDK-key sink for diagnostics markers and error logging.
    ops_stats: Arc<OpsStatsForInstance>,
    /// Weak handle: the client does not keep the provider alive, and requests fail with
    /// `NetworkError::RequestFailed` once it has been dropped.
    net_provider: Weak<dyn NetworkProvider>,
}
32
33impl NetworkClient {
34 #[must_use]
35 pub fn new(sdk_key: &str, headers: Option<HashMap<String, String>>) -> Self {
36 let net_provider = get_network_provider();
37
38 NetworkClient {
39 headers: headers.unwrap_or_default(),
40 is_shutdown: Arc::new(AtomicBool::new(false)),
41 net_provider,
42 ops_stats: OPS_STATS.get_for_instance(sdk_key),
43 }
44 }
45
46 pub fn shutdown(&self) {
47 self.is_shutdown.store(true, Ordering::SeqCst);
48 }
49
50 pub async fn get(&self, request_args: RequestArgs) -> Result<String, NetworkError> {
51 self.make_request(HttpMethod::GET, request_args).await
52 }
53
54 pub async fn post(
55 &self,
56 mut request_args: RequestArgs,
57 body: Option<Vec<u8>>,
58 ) -> Result<String, NetworkError> {
59 request_args.body = body;
60 self.make_request(HttpMethod::POST, request_args).await
61 }
62
63 async fn make_request(
64 &self,
65 method: HttpMethod,
66 mut request_args: RequestArgs,
67 ) -> Result<String, NetworkError> {
68 let is_shutdown = if let Some(is_shutdown) = &request_args.is_shutdown {
69 is_shutdown.clone()
70 } else {
71 self.is_shutdown.clone()
72 };
73
74 request_args.populate_headers(self.headers.clone());
75
76 let mut merged_headers = request_args.headers.unwrap_or_default();
77 if !self.headers.is_empty() {
78 merged_headers.extend(self.headers.clone());
79 }
80 merged_headers.insert(
81 "STATSIG-CLIENT-TIME".into(),
82 Utc::now().timestamp_millis().to_string(),
83 );
84 request_args.headers = Some(merged_headers);
85
86 let mut attempt = 0;
87
88 loop {
89 if let Some(key) = request_args.diagnostics_key {
90 self.ops_stats.add_marker(
91 Marker::new(key, ActionType::Start, Some(StepType::NetworkRequest))
92 .with_attempt(attempt)
93 .with_url(request_args.url.clone()),
94 None,
95 );
96 }
97 if is_shutdown.load(Ordering::SeqCst) {
98 log_i!(TAG, "{}", SHUTDOWN_ERROR);
99 return Err(NetworkError::ShutdownError);
100 }
101
102 let response = match self.net_provider.upgrade() {
103 Some(net_provider) => net_provider.send(&method, &request_args).await,
104 None => return Err(NetworkError::RequestFailed),
105 };
106
107 log_d!(
108 TAG,
109 "Response ({}): {}",
110 &request_args.url,
111 response.status_code
112 );
113
114 let status = response.status_code;
115 let sdk_region_str = response
116 .headers
117 .as_ref()
118 .and_then(|h| h.get("x-statsig-region"));
119 let success = (200..300).contains(&status);
120
121 let error_message = response
122 .error
123 .unwrap_or_else(|| get_error_message_for_status(status));
124
125 if let Some(key) = request_args.diagnostics_key {
126 let mut end_marker =
127 Marker::new(key, ActionType::End, Some(StepType::NetworkRequest))
128 .with_attempt(attempt)
129 .with_url(request_args.url.clone())
130 .with_status_code(status)
131 .with_is_success(success)
132 .with_sdk_region(sdk_region_str.map(|s| s.to_owned()));
133
134 let error_map = if !error_message.is_empty() {
135 let mut map = HashMap::new();
136 map.insert("name".to_string(), "NetworkError".to_string());
137 map.insert("message".to_string(), error_message.clone());
138 map.insert("code".to_string(), status.to_string());
139 Some(map)
140 } else {
141 None
142 };
143
144 if let Some(error_map) = error_map {
145 end_marker = end_marker.with_error(error_map);
146 }
147
148 self.ops_stats.add_marker(end_marker, None);
149 }
150
151 if success {
152 return get_data_as_string(response.data);
153 }
154
155 if !RETRY_CODES.contains(&status) {
156 log_error_to_statsig_and_console!(
157 &self.ops_stats,
158 TAG,
159 "status:{} message:{}",
160 status,
161 error_message
162 );
163 return Err(NetworkError::RequestFailed);
164 }
165
166 if attempt >= request_args.retries {
167 log_error_to_statsig_and_console!(
168 &self.ops_stats,
169 TAG,
170 "Network error, retries exhausted: {} {}",
171 status,
172 error_message
173 );
174 return Err(NetworkError::RetriesExhausted);
175 }
176
177 attempt += 1;
178 let backoff_ms = 2_u64.pow(attempt) * 100;
179
180 log_w!(
181 TAG, "Network request failed with status code {} (attempt {}), will retry after {}ms...\n{}",
182 status,
183 attempt,
184 backoff_ms,
185 error_message
186 );
187
188 tokio::time::sleep(Duration::from_millis(backoff_ms)).await;
189 }
190 }
191}
192
/// Maps an HTTP status code to a short human-readable error description.
///
/// Returns:
/// - an empty string for any 2xx status;
/// - a reason phrase for the statuses listed below;
/// - `"Unknown Error"` for `0`;
/// - `"HTTP Error <status>"` for everything else.
fn get_error_message_for_status(status: u16) -> String {
    match status {
        200..=299 => String::new(),
        0 => String::from("Unknown Error"),
        400 => String::from("Bad Request"),
        401 => String::from("Unauthorized"),
        403 => String::from("Forbidden"),
        404 => String::from("Not Found"),
        405 => String::from("Method Not Allowed"),
        406 => String::from("Not Acceptable"),
        408 => String::from("Request Timeout"),
        500 => String::from("Internal Server Error"),
        502 => String::from("Bad Gateway"),
        503 => String::from("Service Unavailable"),
        504 => String::from("Gateway Timeout"),
        other => format!("HTTP Error {other}"),
    }
}
214
215fn get_data_as_string(data: Option<Vec<u8>>) -> Result<String, NetworkError> {
216 match data {
218 Some(data) => Ok(String::from_utf8(data).map_err(|_| NetworkError::SerializationError)?),
219 None => Err(NetworkError::RequestFailed),
220 }
221}