// statsig_rust/networking/network_client.rs

1use chrono::Utc;
2
3use super::providers::get_network_provider;
4use super::{HttpMethod, NetworkProvider, RequestArgsTyped};
5use crate::observability::ops_stats::{OpsStatsForInstance, OPS_STATS};
6use crate::observability::ErrorBoundaryEvent;
7use crate::sdk_diagnostics::marker::{ActionType, Marker, StepType};
8use crate::{log_d, log_error_to_statsig_and_console, log_i, log_w};
9use std::collections::HashMap;
10use std::sync::atomic::{AtomicBool, Ordering};
11use std::sync::{Arc, Weak};
12use std::time::Duration;
13
/// Status codes treated as transient: a response carrying one of these is
/// retried (with backoff) until the request's retry budget is spent.
const RETRY_CODES: [u16; 8] = [
    408, 500, 502, 503, 504, // timeout and server-side failures
    522, 524, 599, // NOTE(review): non-standard codes, presumably proxy/CDN timeouts
];

/// Message logged when a request is abandoned because shutdown was requested.
const SHUTDOWN_ERROR: &str = "Request was aborted because the client is shutting down";
16
/// Ways a request issued through `NetworkClient` can fail.
#[derive(Debug, PartialEq)]
pub enum NetworkError {
    /// The shutdown flag (the client's, or the request's own override) was set
    /// before an attempt could be made.
    ShutdownError,
    /// The server answered with a non-retryable status, or the network provider
    /// was no longer available.
    RequestFailed,
    /// Every permitted attempt came back with a retryable status.
    RetriesExhausted,
    /// A successful response body could not be deserialized; carries the error text.
    SerializationError(String),
}
/// Tag passed to this module's logging macros.
const TAG: &str = "NetworkClient";
25
26pub struct NetworkClient {
27    headers: HashMap<String, String>,
28    is_shutdown: Arc<AtomicBool>,
29    ops_stats: Arc<OpsStatsForInstance>,
30    net_provider: Weak<dyn NetworkProvider>,
31}
32
33impl NetworkClient {
34    #[must_use]
35    pub fn new(sdk_key: &str, headers: Option<HashMap<String, String>>) -> Self {
36        let net_provider = get_network_provider();
37
38        NetworkClient {
39            headers: headers.unwrap_or_default(),
40            is_shutdown: Arc::new(AtomicBool::new(false)),
41            net_provider,
42            ops_stats: OPS_STATS.get_for_instance(sdk_key),
43        }
44    }
45
46    pub fn shutdown(&self) {
47        self.is_shutdown.store(true, Ordering::SeqCst);
48    }
49
50    pub async fn get<T>(&self, request_args: RequestArgsTyped<T>) -> Result<T, NetworkError>
51    where
52        T: Clone,
53    {
54        self.make_request(HttpMethod::GET, request_args).await
55    }
56
57    pub async fn post<T>(
58        &self,
59        mut request_args: RequestArgsTyped<T>,
60        body: Option<Vec<u8>>,
61    ) -> Result<T, NetworkError>
62    where
63        T: Clone,
64    {
65        request_args.body = body;
66        self.make_request(HttpMethod::POST, request_args).await
67    }
68
69    async fn make_request<T>(
70        &self,
71        method: HttpMethod,
72        mut request_args: RequestArgsTyped<T>,
73    ) -> Result<T, NetworkError>
74    where
75        T: Clone,
76    {
77        let is_shutdown = if let Some(is_shutdown) = &request_args.is_shutdown {
78            is_shutdown.clone()
79        } else {
80            self.is_shutdown.clone()
81        };
82
83        request_args.populate_headers(self.headers.clone());
84
85        let mut merged_headers = request_args.headers.unwrap_or_default();
86        if !self.headers.is_empty() {
87            merged_headers.extend(self.headers.clone());
88        }
89        merged_headers.insert(
90            "STATSIG-CLIENT-TIME".into(),
91            Utc::now().timestamp_millis().to_string(),
92        );
93        request_args.headers = Some(merged_headers);
94
95        let mut attempt = 0;
96
97        loop {
98            let request_args_clone = request_args.clone();
99            if let Some(key) = request_args_clone.diagnostics_key {
100                self.ops_stats.add_marker(
101                    Marker::new(key, ActionType::Start, Some(StepType::NetworkRequest))
102                        .with_attempt(attempt)
103                        .with_url(request_args.url.clone()),
104                    None,
105                );
106            }
107            if is_shutdown.load(Ordering::SeqCst) {
108                log_i!(TAG, "{}", SHUTDOWN_ERROR);
109                return Err(NetworkError::ShutdownError);
110            }
111
112            let response = match self.net_provider.upgrade() {
113                Some(net_provider) => net_provider.send(&method, &request_args_clone.into()).await,
114                None => return Err(NetworkError::RequestFailed),
115            };
116
117            log_d!(
118                TAG,
119                "Response ({}): {}",
120                &request_args.url,
121                response.status_code
122            );
123
124            let status = response.status_code;
125            let sdk_region_str = response
126                .headers
127                .as_ref()
128                .and_then(|h| h.get("x-statsig-region"));
129            let success = (200..300).contains(&status);
130
131            let error_message = response
132                .error
133                .unwrap_or_else(|| get_error_message_for_status(status));
134
135            if let Some(key) = request_args.diagnostics_key {
136                let mut end_marker =
137                    Marker::new(key, ActionType::End, Some(StepType::NetworkRequest))
138                        .with_attempt(attempt)
139                        .with_url(request_args.url.clone())
140                        .with_status_code(status)
141                        .with_is_success(success)
142                        .with_sdk_region(sdk_region_str.map(|s| s.to_owned()));
143
144                let error_map = if !error_message.is_empty() {
145                    let mut map = HashMap::new();
146                    map.insert("name".to_string(), "NetworkError".to_string());
147                    map.insert("message".to_string(), error_message.clone());
148                    map.insert("code".to_string(), status.to_string());
149                    Some(map)
150                } else {
151                    None
152                };
153
154                if let Some(error_map) = error_map {
155                    end_marker = end_marker.with_error(error_map);
156                }
157
158                self.ops_stats.add_marker(end_marker, None);
159            }
160
161            if success {
162                return (request_args.response_deserializer)(response.data)
163                    .map_err(|e| NetworkError::SerializationError(e.to_string()));
164            }
165
166            if !RETRY_CODES.contains(&status) {
167                log_error_to_statsig_and_console!(
168                    &self.ops_stats,
169                    TAG,
170                    "status:{} message:{}",
171                    status,
172                    error_message
173                );
174                return Err(NetworkError::RequestFailed);
175            }
176
177            if attempt >= request_args.retries {
178                log_error_to_statsig_and_console!(
179                    &self.ops_stats,
180                    TAG,
181                    "Network error, retries exhausted: {} {}",
182                    status,
183                    error_message
184                );
185                return Err(NetworkError::RetriesExhausted);
186            }
187
188            attempt += 1;
189            let backoff_ms = 2_u64.pow(attempt) * 100;
190
191            log_w!(
192                TAG, "Network request failed with status code {} (attempt {}), will retry after {}ms...\n{}",
193                status,
194                attempt,
195                backoff_ms,
196                error_message
197            );
198
199            tokio::time::sleep(Duration::from_millis(backoff_ms)).await;
200        }
201    }
202}
203
/// Returns a short human-readable description for an HTTP status code.
///
/// Any 2xx status yields an empty string, meaning "no error". A handful of
/// well-known codes map to their standard reason phrase, and status `0` maps to
/// "Unknown Error". Every other code falls back to `"HTTP Error <code>"`.
fn get_error_message_for_status(status: u16) -> String {
    let reason: &str = match status {
        200..=299 => "",
        0 => "Unknown Error",
        400 => "Bad Request",
        401 => "Unauthorized",
        403 => "Forbidden",
        404 => "Not Found",
        405 => "Method Not Allowed",
        406 => "Not Acceptable",
        408 => "Request Timeout",
        500 => "Internal Server Error",
        502 => "Bad Gateway",
        503 => "Service Unavailable",
        504 => "Gateway Timeout",
        other => return format!("HTTP Error {other}"),
    };
    reason.to_owned()
}
224}