statsig_rust/networking/network_client.rs

1use chrono::Utc;
2use serde::Serialize;
3
4use super::providers::get_network_provider;
5use super::{HttpMethod, NetworkProvider, RequestArgs, Response};
6use crate::observability::ops_stats::{OpsStatsForInstance, OPS_STATS};
7use crate::observability::ErrorBoundaryEvent;
8use crate::sdk_diagnostics::marker::{ActionType, Marker, StepType};
9use crate::{log_d, log_i, log_w, StatsigErr, StatsigOptions};
10use std::collections::HashMap;
11use std::fmt;
12use std::sync::atomic::{AtomicBool, Ordering};
13use std::sync::{Arc, Weak};
14use std::time::Duration;
15
/// HTTP status codes that `NetworkClient` treats as transient: a response with one of these
/// statuses is retried (with exponential backoff) until `RequestArgs::retries` is used up.
/// Any other non-2xx status fails immediately with `NetworkError::RequestNotRetryable`.
const RETRY_CODES: [u16; 8] = [
    408, //
    500, 502, 503, 504, //
    522, 524, 599,
];

/// Message logged when an attempt is skipped because the shutdown flag is set.
const SHUTDOWN_ERROR: &str = "Request was aborted because the client is shutting down";
18
/// Failure modes reported by `NetworkClient` requests.
#[derive(PartialEq, Debug, Clone, Serialize)]
pub enum NetworkError {
    /// The shutdown flag (the client's own, or the per-request override in
    /// `RequestArgs::is_shutdown`) was set when an attempt was about to start.
    ShutdownError,
    /// The request could not be sent. Within `NetworkClient` this is returned when the
    /// network provider (held as a `Weak`) has already been dropped.
    RequestFailed,
    /// A retryable status (one of `RETRY_CODES`) was still returned after all
    /// `RequestArgs::retries` retries had been used.
    RetriesExhausted,
    /// A serialization failure carrying a description. Not constructed by `NetworkClient`
    /// itself — presumably produced by callers that encode/decode payloads (TODO confirm).
    SerializationError(String),
    /// The `disable_network` option is enabled, so no request was attempted.
    DisableNetworkOn,
    /// A non-2xx status outside `RETRY_CODES` was returned; the request is not retried.
    RequestNotRetryable,
}
28
29impl fmt::Display for NetworkError {
30    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
31        match self {
32            NetworkError::ShutdownError => write!(f, "ShutdownError"),
33            NetworkError::RequestFailed => write!(f, "RequestFailed"),
34            NetworkError::RetriesExhausted => write!(f, "RetriesExhausted"),
35            NetworkError::SerializationError(s) => write!(f, "SerializationError: {s}"),
36            NetworkError::DisableNetworkOn => write!(f, "DisableNetworkOn"),
37            NetworkError::RequestNotRetryable => write!(f, "RequestNotRetryable"),
38        }
39    }
40}
41
/// Log tag used by every message emitted from this module.
const TAG: &str = "NetworkClient";
43
/// HTTP client wrapper: adds shared headers to each request, records diagnostics markers,
/// and retries retryable failures with exponential backoff via a `NetworkProvider`.
pub struct NetworkClient {
    /// Headers added to every request; applied after the request's own headers, so they
    /// override any request header with the same name.
    headers: HashMap<String, String>,
    /// Client-wide shutdown flag set by `shutdown()`; a request may supply its own flag
    /// through `RequestArgs::is_shutdown`, which then takes precedence.
    is_shutdown: Arc<AtomicBool>,
    /// Per-SDK-key sink for diagnostics markers and error-boundary events.
    ops_stats: Arc<OpsStatsForInstance>,
    /// Weak handle to the transport; once it can no longer be upgraded, requests fail with
    /// `NetworkError::RequestFailed`.
    net_provider: Weak<dyn NetworkProvider>,
    /// When true, every request returns `NetworkError::DisableNetworkOn` immediately.
    disable_network: bool,
    /// When true (set by `mute_network_error_log`), request failures are still logged as
    /// warnings but are not reported to ops stats.
    silent_on_network_failure: bool,
}
52
53impl NetworkClient {
54    #[must_use]
55    pub fn new(
56        sdk_key: &str,
57        headers: Option<HashMap<String, String>>,
58        options: Option<&StatsigOptions>,
59    ) -> Self {
60        let net_provider = get_network_provider();
61        let disable_network: bool = options
62            .and_then(|opts| opts.disable_network)
63            .unwrap_or(false);
64
65        NetworkClient {
66            headers: headers.unwrap_or_default(),
67            is_shutdown: Arc::new(AtomicBool::new(false)),
68            net_provider,
69            ops_stats: OPS_STATS.get_for_instance(sdk_key),
70            disable_network,
71            silent_on_network_failure: false,
72        }
73    }
74
75    pub fn shutdown(&self) {
76        self.is_shutdown.store(true, Ordering::SeqCst);
77    }
78
79    pub async fn get(&self, request_args: RequestArgs) -> Result<Response, NetworkError> {
80        self.make_request(HttpMethod::GET, request_args).await
81    }
82
83    pub async fn post(
84        &self,
85        mut request_args: RequestArgs,
86        body: Option<Vec<u8>>,
87    ) -> Result<Response, NetworkError> {
88        request_args.body = body;
89        self.make_request(HttpMethod::POST, request_args).await
90    }
91
92    async fn make_request(
93        &self,
94        method: HttpMethod,
95        mut request_args: RequestArgs,
96    ) -> Result<Response, NetworkError> {
97        let is_shutdown = if let Some(is_shutdown) = &request_args.is_shutdown {
98            is_shutdown.clone()
99        } else {
100            self.is_shutdown.clone()
101        };
102
103        if self.disable_network {
104            log_d!(TAG, "Network is disabled, not making requests");
105            return Err(NetworkError::DisableNetworkOn);
106        }
107
108        request_args.populate_headers(self.headers.clone());
109
110        let mut merged_headers = request_args.headers.unwrap_or_default();
111        if !self.headers.is_empty() {
112            merged_headers.extend(self.headers.clone());
113        }
114        merged_headers.insert(
115            "STATSIG-CLIENT-TIME".into(),
116            Utc::now().timestamp_millis().to_string(),
117        );
118        request_args.headers = Some(merged_headers);
119
120        let mut attempt = 0;
121
122        loop {
123            if let Some(key) = request_args.diagnostics_key {
124                self.ops_stats.add_marker(
125                    Marker::new(key, ActionType::Start, Some(StepType::NetworkRequest))
126                        .with_attempt(attempt)
127                        .with_url(request_args.url.clone()),
128                    None,
129                );
130            }
131            if is_shutdown.load(Ordering::SeqCst) {
132                log_i!(TAG, "{}", SHUTDOWN_ERROR);
133                return Err(NetworkError::ShutdownError);
134            }
135
136            let response = match self.net_provider.upgrade() {
137                Some(net_provider) => net_provider.send(&method, &request_args).await,
138                None => return Err(NetworkError::RequestFailed),
139            };
140
141            log_d!(
142                TAG,
143                "Response ({}): {}",
144                &request_args.url,
145                response.status_code
146            );
147
148            let status = response.status_code;
149            let sdk_region_str = response
150                .headers
151                .as_ref()
152                .and_then(|h| h.get("x-statsig-region"));
153            let success = (200..300).contains(&status);
154
155            let error_message = response
156                .error
157                .clone()
158                .unwrap_or_else(|| get_error_message_for_status(status));
159
160            if let Some(key) = request_args.diagnostics_key {
161                let mut end_marker =
162                    Marker::new(key, ActionType::End, Some(StepType::NetworkRequest))
163                        .with_attempt(attempt)
164                        .with_url(request_args.url.clone())
165                        .with_status_code(status)
166                        .with_is_success(success)
167                        .with_sdk_region(sdk_region_str.map(|s| s.to_owned()));
168
169                let error_map = if !error_message.is_empty() {
170                    let mut map = HashMap::new();
171                    map.insert("name".to_string(), "NetworkError".to_string());
172                    map.insert("message".to_string(), error_message.clone());
173                    map.insert("code".to_string(), status.to_string());
174                    Some(map)
175                } else {
176                    None
177                };
178
179                if let Some(error_map) = error_map {
180                    end_marker = end_marker.with_error(error_map);
181                }
182
183                self.ops_stats.add_marker(end_marker, None);
184            }
185
186            if success {
187                return Ok(response);
188            }
189
190            if !RETRY_CODES.contains(&status) {
191                let msg = format!("Network error, not retrying: {} {}", status, error_message);
192                self.log_warning(
193                    StatsigErr::NetworkError(NetworkError::RequestNotRetryable, Some(msg)),
194                    &request_args,
195                );
196                return Err(NetworkError::RequestNotRetryable);
197            }
198
199            if attempt >= request_args.retries {
200                let msg = format!(
201                    "Network error, retries exhausted: {} {}",
202                    status, error_message
203                );
204                self.log_warning(
205                    StatsigErr::NetworkError(NetworkError::RetriesExhausted, Some(msg)),
206                    &request_args,
207                );
208                return Err(NetworkError::RetriesExhausted);
209            }
210
211            attempt += 1;
212            let backoff_ms = 2_u64.pow(attempt) * 100;
213
214            log_w!(
215                TAG, "Network request failed with status code {} (attempt {}), will retry after {}ms...\n{}",
216                status,
217                attempt,
218                backoff_ms,
219                error_message
220            );
221
222            tokio::time::sleep(Duration::from_millis(backoff_ms)).await;
223        }
224    }
225
226    pub fn mute_network_error_log(mut self) -> Self {
227        self.silent_on_network_failure = true;
228        self
229    }
230
231    fn log_warning(&self, error: StatsigErr, args: &RequestArgs) {
232        log_w!(TAG, "{}", error);
233        if !self.silent_on_network_failure {
234            let dedupe_key = format!("{:?}", args.diagnostics_key);
235            self.ops_stats.log_error(ErrorBoundaryEvent {
236                tag: TAG.to_string(),
237                bypass_dedupe: false,
238                info: error,
239                dedupe_key: Some(dedupe_key),
240                extra: None,
241            });
242        }
243    }
244}
245
/// Maps an HTTP status code to a short, human-readable reason.
///
/// Success codes (200–299) map to an empty string, `0` maps to `"Unknown Error"`, a handful
/// of common 4xx/5xx codes map to their standard reason phrase, and anything else falls back
/// to `"HTTP Error <code>"`.
fn get_error_message_for_status(status: u16) -> String {
    let reason: &'static str = match status {
        200..=299 => "",
        0 => "Unknown Error",
        400 => "Bad Request",
        401 => "Unauthorized",
        403 => "Forbidden",
        404 => "Not Found",
        405 => "Method Not Allowed",
        406 => "Not Acceptable",
        408 => "Request Timeout",
        500 => "Internal Server Error",
        502 => "Bad Gateway",
        503 => "Service Unavailable",
        504 => "Gateway Timeout",
        other => return format!("HTTP Error {other}"),
    };
    reason.to_owned()
}