statsig_rust/networking/network_client.rs

1use chrono::Utc;
2
3use super::providers::get_network_provider;
4use super::{HttpMethod, NetworkProvider, RequestArgs, Response};
5use crate::observability::ops_stats::{OpsStatsForInstance, OPS_STATS};
6use crate::observability::ErrorBoundaryEvent;
7use crate::sdk_diagnostics::marker::{ActionType, Marker, StepType};
8use crate::{log_d, log_i, log_w};
9use std::collections::HashMap;
10use std::fmt;
11use std::sync::atomic::{AtomicBool, Ordering};
12use std::sync::{Arc, Weak};
13use std::time::Duration;
14
/// HTTP statuses treated as transient by `NetworkClient::make_request`: a failure with one of
/// these is retried with exponential backoff until the request's retry budget is spent, while any
/// other non-2xx status fails immediately.
const RETRY_CODES: [u16; 8] = [
    408, // Request Timeout
    500, 502, 503, 504, // standard 5xx server / gateway errors
    522, 524, // non-standard 52x timeouts (commonly emitted by Cloudflare)
    599, // non-standard; often used for network connect timeouts
];

/// Message logged when a request is abandoned because shutdown has been signalled.
const SHUTDOWN_ERROR: &str = "Request was aborted because the client is shutting down";
17
/// Failure modes reported by `NetworkClient` requests.
#[derive(PartialEq, Eq, Debug, Clone)]
pub enum NetworkError {
    /// The effective shutdown flag was set before the request could be attempted.
    ShutdownError,
    /// The response had a non-2xx status that is not retryable, or no network provider was
    /// available to send the request.
    RequestFailed,
    /// Every allowed attempt failed with a retryable status.
    RetriesExhausted,
    /// Serialization failure; the carried text is appended to the `Display` output.
    SerializationError(String),
    /// Networking is disabled for the client, so nothing was sent.
    DisableNetworkOn,
}

impl fmt::Display for NetworkError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            NetworkError::ShutdownError => write!(f, "ShutdownError"),
            NetworkError::RequestFailed => write!(f, "RequestFailed"),
            NetworkError::RetriesExhausted => write!(f, "RetriesExhausted"),
            NetworkError::SerializationError(s) => write!(f, "SerializationError: {s}"),
            NetworkError::DisableNetworkOn => write!(f, "DisableNetworkOn"),
        }
    }
}

/// Makes `NetworkError` a first-class error type so it can be boxed as `dyn Error` or
/// propagated with `?` through generic error handling.
impl std::error::Error for NetworkError {}
38
/// Tag attached to every log line and error-boundary event emitted from this file.
const TAG: &str = "NetworkClient";
40
41pub struct NetworkClient {
42    headers: HashMap<String, String>,
43    is_shutdown: Arc<AtomicBool>,
44    ops_stats: Arc<OpsStatsForInstance>,
45    net_provider: Weak<dyn NetworkProvider>,
46    disable_network: bool,
47    silent_on_network_failure: bool,
48}
49
50impl NetworkClient {
51    #[must_use]
52    pub fn new(
53        sdk_key: &str,
54        headers: Option<HashMap<String, String>>,
55        disable_network: Option<bool>,
56    ) -> Self {
57        let net_provider = get_network_provider();
58
59        NetworkClient {
60            headers: headers.unwrap_or_default(),
61            is_shutdown: Arc::new(AtomicBool::new(false)),
62            net_provider,
63            ops_stats: OPS_STATS.get_for_instance(sdk_key),
64            disable_network: disable_network.unwrap_or_default(),
65            silent_on_network_failure: false,
66        }
67    }
68
69    pub fn shutdown(&self) {
70        self.is_shutdown.store(true, Ordering::SeqCst);
71    }
72
73    pub async fn get(&self, request_args: RequestArgs) -> Result<Response, NetworkError> {
74        self.make_request(HttpMethod::GET, request_args).await
75    }
76
77    pub async fn post(
78        &self,
79        mut request_args: RequestArgs,
80        body: Option<Vec<u8>>,
81    ) -> Result<Response, NetworkError> {
82        request_args.body = body;
83        self.make_request(HttpMethod::POST, request_args).await
84    }
85
86    async fn make_request(
87        &self,
88        method: HttpMethod,
89        mut request_args: RequestArgs,
90    ) -> Result<Response, NetworkError> {
91        let is_shutdown = if let Some(is_shutdown) = &request_args.is_shutdown {
92            is_shutdown.clone()
93        } else {
94            self.is_shutdown.clone()
95        };
96
97        if self.disable_network {
98            log_d!(TAG, "Network is disabled, not making requests");
99            return Err(NetworkError::DisableNetworkOn);
100        }
101
102        request_args.populate_headers(self.headers.clone());
103
104        let mut merged_headers = request_args.headers.unwrap_or_default();
105        if !self.headers.is_empty() {
106            merged_headers.extend(self.headers.clone());
107        }
108        merged_headers.insert(
109            "STATSIG-CLIENT-TIME".into(),
110            Utc::now().timestamp_millis().to_string(),
111        );
112        request_args.headers = Some(merged_headers);
113
114        let mut attempt = 0;
115
116        loop {
117            if let Some(key) = request_args.diagnostics_key {
118                self.ops_stats.add_marker(
119                    Marker::new(key, ActionType::Start, Some(StepType::NetworkRequest))
120                        .with_attempt(attempt)
121                        .with_url(request_args.url.clone()),
122                    None,
123                );
124            }
125            if is_shutdown.load(Ordering::SeqCst) {
126                log_i!(TAG, "{}", SHUTDOWN_ERROR);
127                return Err(NetworkError::ShutdownError);
128            }
129
130            let response = match self.net_provider.upgrade() {
131                Some(net_provider) => net_provider.send(&method, &request_args).await,
132                None => return Err(NetworkError::RequestFailed),
133            };
134
135            log_d!(
136                TAG,
137                "Response ({}): {}",
138                &request_args.url,
139                response.status_code
140            );
141
142            let status = response.status_code;
143            let sdk_region_str = response
144                .headers
145                .as_ref()
146                .and_then(|h| h.get("x-statsig-region"));
147            let success = (200..300).contains(&status);
148
149            let error_message = response
150                .error
151                .clone()
152                .unwrap_or_else(|| get_error_message_for_status(status));
153
154            if let Some(key) = request_args.diagnostics_key {
155                let mut end_marker =
156                    Marker::new(key, ActionType::End, Some(StepType::NetworkRequest))
157                        .with_attempt(attempt)
158                        .with_url(request_args.url.clone())
159                        .with_status_code(status)
160                        .with_is_success(success)
161                        .with_sdk_region(sdk_region_str.map(|s| s.to_owned()));
162
163                let error_map = if !error_message.is_empty() {
164                    let mut map = HashMap::new();
165                    map.insert("name".to_string(), "NetworkError".to_string());
166                    map.insert("message".to_string(), error_message.clone());
167                    map.insert("code".to_string(), status.to_string());
168                    Some(map)
169                } else {
170                    None
171                };
172
173                if let Some(error_map) = error_map {
174                    end_marker = end_marker.with_error(error_map);
175                }
176
177                self.ops_stats.add_marker(end_marker, None);
178            }
179
180            if success {
181                return Ok(response);
182            }
183
184            if !RETRY_CODES.contains(&status) {
185                self.log_warning(
186                    format!("status:{} message:{}", status, error_message),
187                    &request_args,
188                );
189                return Err(NetworkError::RequestFailed);
190            }
191
192            if attempt >= request_args.retries {
193                self.log_warning(
194                    format!(
195                        "Network error, retries exhausted: {} {}",
196                        status, error_message
197                    ),
198                    &request_args,
199                );
200                return Err(NetworkError::RetriesExhausted);
201            }
202
203            attempt += 1;
204            let backoff_ms = 2_u64.pow(attempt) * 100;
205
206            log_w!(
207                TAG, "Network request failed with status code {} (attempt {}), will retry after {}ms...\n{}",
208                status,
209                attempt,
210                backoff_ms,
211                error_message
212            );
213
214            tokio::time::sleep(Duration::from_millis(backoff_ms)).await;
215        }
216    }
217
218    pub fn mute_network_error_log(mut self) -> Self {
219        self.silent_on_network_failure = true;
220        self
221    }
222
223    fn log_warning(&self, message: String, args: &RequestArgs) {
224        log_w!(TAG, "{}", message);
225        if !self.silent_on_network_failure {
226            let dedupe_key = format!("{:?}", args.diagnostics_key);
227            self.ops_stats.log_error(ErrorBoundaryEvent {
228                tag: TAG.to_string(),
229                bypass_dedupe: false,
230                exception: message,
231                dedupe_key: Some(dedupe_key),
232                extra: None,
233            });
234        }
235    }
236}
237
/// Maps an HTTP status code to a short, human-readable error description.
///
/// Returns:
/// * an empty string for any 2xx (success) status;
/// * a fixed reason phrase for a handful of common error codes;
/// * `"Unknown Error"` for status `0`;
/// * `"HTTP Error <code>"` for everything else.
fn get_error_message_for_status(status: u16) -> String {
    let reason = match status {
        // Successful responses carry no error message.
        200..=299 => return String::new(),
        0 => "Unknown Error",
        400 => "Bad Request",
        401 => "Unauthorized",
        403 => "Forbidden",
        404 => "Not Found",
        405 => "Method Not Allowed",
        406 => "Not Acceptable",
        408 => "Request Timeout",
        500 => "Internal Server Error",
        502 => "Bad Gateway",
        503 => "Service Unavailable",
        504 => "Gateway Timeout",
        other => return format!("HTTP Error {other}"),
    };
    reason.to_owned()
}