// statsig_rust/networking/network_client.rs

1use chrono::Utc;
2
3use super::providers::get_network_provider;
4use super::{HttpMethod, NetworkProvider, RequestArgsTyped};
5use crate::observability::ops_stats::{OpsStatsForInstance, OPS_STATS};
6use crate::observability::ErrorBoundaryEvent;
7use crate::sdk_diagnostics::marker::{ActionType, Marker, StepType};
8use crate::{log_d, log_error_to_statsig_and_console, log_i, log_w, StatsigErr};
9use std::collections::HashMap;
10use std::fmt;
11use std::sync::atomic::{AtomicBool, Ordering};
12use std::sync::{Arc, Weak};
13use std::time::Duration;
14
/// HTTP status codes considered transient. A response carrying one of these is retried
/// (while the request's retry budget lasts) instead of failing immediately.
const RETRY_CODES: [u16; 8] = [
    408, // Request Timeout
    500, // Internal Server Error
    502, // Bad Gateway
    503, // Service Unavailable
    504, // Gateway Timeout
    522, // Cloudflare: connection timed out
    524, // Cloudflare: a timeout occurred
    599, // non-standard network connect timeout
];

/// Message logged when an attempt is abandoned because the shutdown flag is set.
const SHUTDOWN_ERROR: &str = "Request was aborted because the client is shutting down";
17
/// Failure modes reported by [`NetworkClient`] requests.
#[derive(Debug, Clone, PartialEq)]
pub enum NetworkError {
    /// The shutdown flag (client-wide or per-request) was set before an attempt started.
    ShutdownError,
    /// The request failed without being retried: a non-retryable, non-2xx status was
    /// returned, or the network provider was no longer available.
    RequestFailed,
    /// A retryable status was still being returned after the retry budget was spent.
    RetriesExhausted,
    /// A 2xx response arrived but the response deserializer rejected it; carries the
    /// deserializer's error text.
    SerializationError(String),
    /// The client was built with networking disabled, so nothing was sent.
    DisableNetworkOn,
}
26
27impl fmt::Display for NetworkError {
28    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
29        match self {
30            NetworkError::ShutdownError => write!(f, "ShutdownError"),
31            NetworkError::RequestFailed => write!(f, "RequestFailed"),
32            NetworkError::RetriesExhausted => write!(f, "RetriesExhausted"),
33            NetworkError::SerializationError(s) => write!(f, "SerializationError: {s}"),
34            NetworkError::DisableNetworkOn => write!(f, "DisableNetworkOn"),
35        }
36    }
37}
38
/// Log tag attached to every message emitted by this module.
const TAG: &str = "NetworkClient";
40
41pub struct NetworkClient {
42    headers: HashMap<String, String>,
43    is_shutdown: Arc<AtomicBool>,
44    ops_stats: Arc<OpsStatsForInstance>,
45    net_provider: Weak<dyn NetworkProvider>,
46    disable_network: bool,
47}
48
49impl NetworkClient {
50    #[must_use]
51    pub fn new(
52        sdk_key: &str,
53        headers: Option<HashMap<String, String>>,
54        disable_network: Option<bool>,
55    ) -> Self {
56        let net_provider = get_network_provider();
57
58        NetworkClient {
59            headers: headers.unwrap_or_default(),
60            is_shutdown: Arc::new(AtomicBool::new(false)),
61            net_provider,
62            ops_stats: OPS_STATS.get_for_instance(sdk_key),
63            disable_network: disable_network.unwrap_or_default(),
64        }
65    }
66
67    pub fn shutdown(&self) {
68        self.is_shutdown.store(true, Ordering::SeqCst);
69    }
70
71    pub async fn get<T, D>(&self, request_args: RequestArgsTyped<T, D>) -> Result<T, NetworkError>
72    where
73        T: Clone,
74        D: Fn(Option<&Vec<u8>>) -> Result<T, StatsigErr> + Clone,
75    {
76        self.make_request(HttpMethod::GET, request_args).await
77    }
78
79    pub async fn post<T, D>(
80        &self,
81        mut request_args: RequestArgsTyped<T, D>,
82        body: Option<Vec<u8>>,
83    ) -> Result<T, NetworkError>
84    where
85        T: Clone,
86        D: Fn(Option<&Vec<u8>>) -> Result<T, StatsigErr> + Clone,
87    {
88        request_args.body = body;
89        self.make_request(HttpMethod::POST, request_args).await
90    }
91
92    async fn make_request<T, D>(
93        &self,
94        method: HttpMethod,
95        mut request_args: RequestArgsTyped<T, D>,
96    ) -> Result<T, NetworkError>
97    where
98        T: Clone,
99        D: Fn(Option<&Vec<u8>>) -> Result<T, StatsigErr> + Clone,
100    {
101        let is_shutdown = if let Some(is_shutdown) = &request_args.is_shutdown {
102            is_shutdown.clone()
103        } else {
104            self.is_shutdown.clone()
105        };
106
107        if self.disable_network {
108            log_d!(TAG, "Network is disabled, not making requests");
109            return Err(NetworkError::DisableNetworkOn);
110        }
111
112        request_args.populate_headers(self.headers.clone());
113
114        let mut merged_headers = request_args.headers.unwrap_or_default();
115        if !self.headers.is_empty() {
116            merged_headers.extend(self.headers.clone());
117        }
118        merged_headers.insert(
119            "STATSIG-CLIENT-TIME".into(),
120            Utc::now().timestamp_millis().to_string(),
121        );
122        request_args.headers = Some(merged_headers);
123
124        let mut attempt = 0;
125
126        loop {
127            let request_args_clone = request_args.clone();
128            if let Some(key) = request_args_clone.diagnostics_key {
129                self.ops_stats.add_marker(
130                    Marker::new(key, ActionType::Start, Some(StepType::NetworkRequest))
131                        .with_attempt(attempt)
132                        .with_url(request_args.url.clone()),
133                    None,
134                );
135            }
136            if is_shutdown.load(Ordering::SeqCst) {
137                log_i!(TAG, "{}", SHUTDOWN_ERROR);
138                return Err(NetworkError::ShutdownError);
139            }
140
141            let response = match self.net_provider.upgrade() {
142                Some(net_provider) => net_provider.send(&method, &request_args_clone.into()).await,
143                None => return Err(NetworkError::RequestFailed),
144            };
145
146            log_d!(
147                TAG,
148                "Response ({}): {}",
149                &request_args.url,
150                response.status_code
151            );
152
153            let status = response.status_code;
154            let sdk_region_str = response
155                .headers
156                .as_ref()
157                .and_then(|h| h.get("x-statsig-region"));
158            let success = (200..300).contains(&status);
159
160            let error_message = response
161                .error
162                .unwrap_or_else(|| get_error_message_for_status(status));
163
164            if let Some(key) = request_args.diagnostics_key {
165                let mut end_marker =
166                    Marker::new(key, ActionType::End, Some(StepType::NetworkRequest))
167                        .with_attempt(attempt)
168                        .with_url(request_args.url.clone())
169                        .with_status_code(status)
170                        .with_is_success(success)
171                        .with_sdk_region(sdk_region_str.map(|s| s.to_owned()));
172
173                let error_map = if !error_message.is_empty() {
174                    let mut map = HashMap::new();
175                    map.insert("name".to_string(), "NetworkError".to_string());
176                    map.insert("message".to_string(), error_message.clone());
177                    map.insert("code".to_string(), status.to_string());
178                    Some(map)
179                } else {
180                    None
181                };
182
183                if let Some(error_map) = error_map {
184                    end_marker = end_marker.with_error(error_map);
185                }
186
187                self.ops_stats.add_marker(end_marker, None);
188            }
189
190            if success {
191                return (request_args.response_deserializer)(response.data.as_ref())
192                    .map_err(|e| NetworkError::SerializationError(e.to_string()));
193            }
194
195            if !RETRY_CODES.contains(&status) {
196                log_error_to_statsig_and_console!(
197                    &self.ops_stats,
198                    TAG,
199                    "status:{} message:{}",
200                    status,
201                    error_message
202                );
203                return Err(NetworkError::RequestFailed);
204            }
205
206            if attempt >= request_args.retries {
207                log_error_to_statsig_and_console!(
208                    &self.ops_stats,
209                    TAG,
210                    "Network error, retries exhausted: {} {}",
211                    status,
212                    error_message
213                );
214                return Err(NetworkError::RetriesExhausted);
215            }
216
217            attempt += 1;
218            let backoff_ms = 2_u64.pow(attempt) * 100;
219
220            log_w!(
221                TAG, "Network request failed with status code {} (attempt {}), will retry after {}ms...\n{}",
222                status,
223                attempt,
224                backoff_ms,
225                error_message
226            );
227
228            tokio::time::sleep(Duration::from_millis(backoff_ms)).await;
229        }
230    }
231}
232
/// Returns a short, human-readable description of an HTTP `status`.
///
/// * Any 2xx status yields an empty string, meaning "no error".
/// * Status `0` yields `"Unknown Error"`.
/// * A handful of common 4xx/5xx codes map to their standard reason phrases.
/// * Every other code yields `"HTTP Error <code>"`.
fn get_error_message_for_status(status: u16) -> String {
    let reason = match status {
        200..=299 => return String::new(),
        0 => "Unknown Error",
        400 => "Bad Request",
        401 => "Unauthorized",
        403 => "Forbidden",
        404 => "Not Found",
        405 => "Method Not Allowed",
        406 => "Not Acceptable",
        408 => "Request Timeout",
        500 => "Internal Server Error",
        502 => "Bad Gateway",
        503 => "Service Unavailable",
        504 => "Gateway Timeout",
        other => return format!("HTTP Error {other}"),
    };
    reason.to_owned()
}