statsig_rust/networking/network_client.rs

1use chrono::Utc;
2use serde::Serialize;
3
4use super::providers::get_network_provider;
5use super::{HttpMethod, NetworkProvider, RequestArgs, Response};
6use crate::networking::proxy_config::ProxyConfig;
7use crate::observability::ops_stats::{OpsStatsForInstance, OPS_STATS};
8use crate::observability::ErrorBoundaryEvent;
9use crate::sdk_diagnostics::marker::{ActionType, Marker, StepType};
10use crate::{log_d, log_i, log_w, StatsigErr, StatsigOptions};
11use std::collections::HashMap;
12use std::fmt;
13use std::sync::atomic::{AtomicBool, Ordering};
14use std::sync::{Arc, Weak};
15use std::time::Duration;
16
/// HTTP status codes for which `NetworkClient::make_request` backs off and retries
/// instead of failing immediately.
const RETRY_CODES: [u16; 8] = [
    408, 500, 502, 503, 504, //
    522, 524, 599,
];

/// Message logged when an attempt is skipped because the shutdown flag is set.
const SHUTDOWN_ERROR: &str = "Request was aborted because the client is shutting down";
19
/// Failure reasons reported by `NetworkClient` requests.
///
/// NOTE(review): this type derives `Serialize`; reordering variants changes their serde
/// variant index, which matters for index-based formats.
#[derive(PartialEq, Debug, Clone, Serialize)]
pub enum NetworkError {
    /// The shutdown flag (per-request or client-wide) was set before an attempt started.
    ShutdownError,
    /// In this file: the network provider could no longer be upgraded from its `Weak` handle.
    RequestFailed,
    /// A retryable status was still returned after the final allowed attempt.
    RetriesExhausted,
    /// Serialization failure with a description. Not produced in this file —
    /// presumably raised by callers that encode/decode payloads (confirm).
    SerializationError(String),
    /// Networking is disabled through `StatsigOptions::disable_network`.
    DisableNetworkOn,
    /// A non-2xx status outside `RETRY_CODES` was returned, so no retry was attempted.
    RequestNotRetryable,
}
29
30impl fmt::Display for NetworkError {
31    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
32        match self {
33            NetworkError::ShutdownError => write!(f, "ShutdownError"),
34            NetworkError::RequestFailed => write!(f, "RequestFailed"),
35            NetworkError::RetriesExhausted => write!(f, "RetriesExhausted"),
36            NetworkError::SerializationError(s) => write!(f, "SerializationError: {s}"),
37            NetworkError::DisableNetworkOn => write!(f, "DisableNetworkOn"),
38            NetworkError::RequestNotRetryable => write!(f, "RequestNotRetryable"),
39        }
40    }
41}
42
/// Log tag used by every message emitted from this module.
const TAG: &str = "NetworkClient";
44
/// HTTP client used for SDK network calls: adds default headers, proxy config,
/// shutdown handling, retries with backoff, diagnostics markers and error reporting
/// on top of a `NetworkProvider`.
pub struct NetworkClient {
    // Default headers merged into every outgoing request.
    headers: HashMap<String, String>,
    // Client-wide shutdown flag set by `shutdown()`; used when a request has no flag of its own.
    is_shutdown: Arc<AtomicBool>,
    // Per-SDK-instance ops stats (from `OPS_STATS.get_for_instance`) that receive
    // diagnostics markers and error-boundary events.
    ops_stats: Arc<OpsStatsForInstance>,
    // Held weakly; a failed upgrade makes requests return `NetworkError::RequestFailed`.
    net_provider: Weak<dyn NetworkProvider>,
    // When true, every request returns `NetworkError::DisableNetworkOn` without sending.
    disable_network: bool,
    // Copied into each request's `RequestArgs::proxy_config` when set.
    proxy_config: Option<ProxyConfig>,
    // Set by `mute_network_error_log`; suppresses error-boundary reporting (warnings are still logged).
    silent_on_network_failure: bool,
}
54
55impl NetworkClient {
56    #[must_use]
57    pub fn new(
58        sdk_key: &str,
59        headers: Option<HashMap<String, String>>,
60        options: Option<&StatsigOptions>,
61    ) -> Self {
62        let net_provider = get_network_provider();
63        let (disable_network, proxy_config) = options
64            .map(|opts| {
65                (
66                    opts.disable_network.unwrap_or(false),
67                    opts.proxy_config.clone(),
68                )
69            })
70            .unwrap_or((false, None));
71
72        NetworkClient {
73            headers: headers.unwrap_or_default(),
74            is_shutdown: Arc::new(AtomicBool::new(false)),
75            net_provider,
76            ops_stats: OPS_STATS.get_for_instance(sdk_key),
77            disable_network,
78            proxy_config,
79            silent_on_network_failure: false,
80        }
81    }
82
83    pub fn shutdown(&self) {
84        self.is_shutdown.store(true, Ordering::SeqCst);
85    }
86
87    pub async fn get(&self, request_args: RequestArgs) -> Result<Response, NetworkError> {
88        self.make_request(HttpMethod::GET, request_args).await
89    }
90
91    pub async fn post(
92        &self,
93        mut request_args: RequestArgs,
94        body: Option<Vec<u8>>,
95    ) -> Result<Response, NetworkError> {
96        request_args.body = body;
97        self.make_request(HttpMethod::POST, request_args).await
98    }
99
100    async fn make_request(
101        &self,
102        method: HttpMethod,
103        mut request_args: RequestArgs,
104    ) -> Result<Response, NetworkError> {
105        let is_shutdown = if let Some(is_shutdown) = &request_args.is_shutdown {
106            is_shutdown.clone()
107        } else {
108            self.is_shutdown.clone()
109        };
110
111        if self.disable_network {
112            log_d!(TAG, "Network is disabled, not making requests");
113            return Err(NetworkError::DisableNetworkOn);
114        }
115
116        request_args.populate_headers(self.headers.clone());
117
118        let mut merged_headers = request_args.headers.unwrap_or_default();
119        if !self.headers.is_empty() {
120            merged_headers.extend(self.headers.clone());
121        }
122        merged_headers.insert(
123            "STATSIG-CLIENT-TIME".into(),
124            Utc::now().timestamp_millis().to_string(),
125        );
126        request_args.headers = Some(merged_headers);
127
128        // passing down proxy config through request args
129        if let Some(proxy_config) = &self.proxy_config {
130            request_args.proxy_config = Some(proxy_config.clone());
131        }
132
133        let mut attempt = 0;
134
135        loop {
136            if let Some(key) = request_args.diagnostics_key {
137                self.ops_stats.add_marker(
138                    Marker::new(key, ActionType::Start, Some(StepType::NetworkRequest))
139                        .with_attempt(attempt)
140                        .with_url(request_args.url.clone()),
141                    None,
142                );
143            }
144            if is_shutdown.load(Ordering::SeqCst) {
145                log_i!(TAG, "{}", SHUTDOWN_ERROR);
146                return Err(NetworkError::ShutdownError);
147            }
148
149            let response = match self.net_provider.upgrade() {
150                Some(net_provider) => net_provider.send(&method, &request_args).await,
151                None => return Err(NetworkError::RequestFailed),
152            };
153
154            log_d!(
155                TAG,
156                "Response ({}): {}",
157                &request_args.url,
158                response.status_code
159            );
160
161            let status = response.status_code;
162            let sdk_region_str = response
163                .headers
164                .as_ref()
165                .and_then(|h| h.get("x-statsig-region"));
166            let success = (200..300).contains(&status);
167
168            let error_message = response
169                .error
170                .clone()
171                .unwrap_or_else(|| get_error_message_for_status(status));
172
173            if let Some(key) = request_args.diagnostics_key {
174                let mut end_marker =
175                    Marker::new(key, ActionType::End, Some(StepType::NetworkRequest))
176                        .with_attempt(attempt)
177                        .with_url(request_args.url.clone())
178                        .with_status_code(status)
179                        .with_is_success(success)
180                        .with_sdk_region(sdk_region_str.map(|s| s.to_owned()));
181
182                let error_map = if !error_message.is_empty() {
183                    let mut map = HashMap::new();
184                    map.insert("name".to_string(), "NetworkError".to_string());
185                    map.insert("message".to_string(), error_message.clone());
186                    map.insert("code".to_string(), status.to_string());
187                    Some(map)
188                } else {
189                    None
190                };
191
192                if let Some(error_map) = error_map {
193                    end_marker = end_marker.with_error(error_map);
194                }
195
196                self.ops_stats.add_marker(end_marker, None);
197            }
198
199            if success {
200                return Ok(response);
201            }
202
203            if !RETRY_CODES.contains(&status) {
204                let msg = format!("Network error, not retrying: {} {}", status, error_message);
205                self.log_warning(
206                    StatsigErr::NetworkError(NetworkError::RequestNotRetryable, Some(msg)),
207                    &request_args,
208                );
209                return Err(NetworkError::RequestNotRetryable);
210            }
211
212            if attempt >= request_args.retries {
213                let msg = format!(
214                    "Network error, retries exhausted: {} {}",
215                    status, error_message
216                );
217                self.log_warning(
218                    StatsigErr::NetworkError(NetworkError::RetriesExhausted, Some(msg)),
219                    &request_args,
220                );
221                return Err(NetworkError::RetriesExhausted);
222            }
223
224            attempt += 1;
225            let backoff_ms = 2_u64.pow(attempt) * 100;
226
227            log_w!(
228                TAG, "Network request failed with status code {} (attempt {}), will retry after {}ms...\n{}",
229                status,
230                attempt,
231                backoff_ms,
232                error_message
233            );
234
235            tokio::time::sleep(Duration::from_millis(backoff_ms)).await;
236        }
237    }
238
239    pub fn mute_network_error_log(mut self) -> Self {
240        self.silent_on_network_failure = true;
241        self
242    }
243
244    fn log_warning(&self, error: StatsigErr, args: &RequestArgs) {
245        log_w!(TAG, "{}", error);
246        if !self.silent_on_network_failure {
247            let dedupe_key = format!("{:?}", args.diagnostics_key);
248            self.ops_stats.log_error(ErrorBoundaryEvent {
249                tag: TAG.to_string(),
250                bypass_dedupe: false,
251                info: error,
252                dedupe_key: Some(dedupe_key),
253                extra: None,
254            });
255        }
256    }
257}
258
/// Maps an HTTP status code to a short human-readable description.
///
/// Any 2xx status yields an empty string. A fixed set of common codes gets its reason
/// phrase, and `0` yields `"Unknown Error"`. Every other code becomes `"HTTP Error <code>"`.
fn get_error_message_for_status(status: u16) -> String {
    let reason: &str = match status {
        200..=299 => "",
        0 => "Unknown Error",
        400 => "Bad Request",
        401 => "Unauthorized",
        403 => "Forbidden",
        404 => "Not Found",
        405 => "Method Not Allowed",
        406 => "Not Acceptable",
        408 => "Request Timeout",
        500 => "Internal Server Error",
        502 => "Bad Gateway",
        503 => "Service Unavailable",
        504 => "Gateway Timeout",
        other => return format!("HTTP Error {other}"),
    };
    reason.to_owned()
}