statsig_rust/networking/network_client.rs

1use chrono::Utc;
2
3use super::providers::get_network_provider;
4use super::{HttpMethod, NetworkProvider, RequestArgs};
5use crate::networking::net_utils::sanitize_url_for_logging;
6use crate::observability::ops_stats::{OpsStatsForInstance, OPS_STATS};
7use crate::observability::ErrorBoundaryEvent;
8use crate::sdk_diagnostics::marker::{ActionType, Marker, StepType};
9use crate::{log_d, log_error_to_statsig_and_console, log_i, log_w};
10use std::collections::HashMap;
11use std::sync::atomic::{AtomicBool, Ordering};
12use std::sync::{Arc, Weak};
13use std::time::Duration;
14
/// Non-2xx status codes that `make_request` treats as transient and retries
/// (up to `RequestArgs::retries` times); any other non-2xx status fails at once.
const RETRY_CODES: [u16; 8] = [
    408, //
    500, 502, 503, 504, //
    522, 524, //
    599,
];

/// Message logged when a request is abandoned because shutdown was requested.
const SHUTDOWN_ERROR: &str = "Request was aborted because the client is shutting down";
17
/// Failure modes reported by `NetworkClient` requests.
///
/// The enum is fieldless, so it is also `Copy`/`Eq`; callers can compare and
/// pass errors around by value without cloning.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum NetworkError {
    /// The request was abandoned because a shutdown flag (client-wide or
    /// per-request) was set before an attempt started.
    ShutdownError,
    /// The request could not be completed. In `NetworkClient` this means the
    /// network provider has been dropped, the status code is not retryable,
    /// or a successful response carried no body.
    RequestFailed,
    /// Every allowed attempt came back with a retryable status code.
    RetriesExhausted,
    /// The response body could not be decoded (it was not valid UTF-8).
    SerializationError,
}
/// Log tag attached to every message emitted from this module.
const TAG: &str = "NetworkClient";
26
27pub struct NetworkClient {
28    headers: HashMap<String, String>,
29    is_shutdown: Arc<AtomicBool>,
30    ops_stats: Arc<OpsStatsForInstance>,
31    net_provider: Weak<dyn NetworkProvider>,
32}
33
34impl NetworkClient {
35    #[must_use]
36    pub fn new(sdk_key: &str, headers: Option<HashMap<String, String>>) -> Self {
37        let net_provider = get_network_provider();
38
39        NetworkClient {
40            headers: headers.unwrap_or_default(),
41            is_shutdown: Arc::new(AtomicBool::new(false)),
42            net_provider,
43            ops_stats: OPS_STATS.get_for_instance(sdk_key),
44        }
45    }
46
47    pub fn shutdown(&self) {
48        self.is_shutdown.store(true, Ordering::SeqCst);
49    }
50
51    pub async fn get(&self, request_args: RequestArgs) -> Result<String, NetworkError> {
52        self.make_request(HttpMethod::GET, request_args).await
53    }
54
55    pub async fn post(
56        &self,
57        mut request_args: RequestArgs,
58        body: Option<Vec<u8>>,
59    ) -> Result<String, NetworkError> {
60        request_args.body = body;
61        self.make_request(HttpMethod::POST, request_args).await
62    }
63
64    async fn make_request(
65        &self,
66        method: HttpMethod,
67        mut request_args: RequestArgs,
68    ) -> Result<String, NetworkError> {
69        let is_shutdown = if let Some(is_shutdown) = &request_args.is_shutdown {
70            is_shutdown.clone()
71        } else {
72            self.is_shutdown.clone()
73        };
74
75        request_args.populate_headers(self.headers.clone());
76
77        let mut merged_headers = request_args.headers.unwrap_or_default();
78        if !self.headers.is_empty() {
79            merged_headers.extend(self.headers.clone());
80        }
81        merged_headers.insert(
82            "STATSIG-CLIENT-TIME".into(),
83            Utc::now().timestamp_millis().to_string(),
84        );
85        request_args.headers = Some(merged_headers);
86
87        let mut attempt = 0;
88
89        loop {
90            if let Some(key) = request_args.diagnostics_key {
91                self.ops_stats.add_marker(
92                    Marker::new(key, ActionType::Start, Some(StepType::NetworkRequest))
93                        .with_attempt(attempt)
94                        .with_url(request_args.url.clone()),
95                    None,
96                );
97            }
98            if is_shutdown.load(Ordering::SeqCst) {
99                log_i!(TAG, "{}", SHUTDOWN_ERROR);
100                return Err(NetworkError::ShutdownError);
101            }
102
103            let response = match self.net_provider.upgrade() {
104                Some(net_provider) => net_provider.send(&method, &request_args).await,
105                None => return Err(NetworkError::RequestFailed),
106            };
107
108            let sanitized_url = sanitize_url_for_logging(&request_args.url);
109            log_d!(
110                TAG,
111                "Response ({}): {}",
112                sanitized_url,
113                response.status_code
114            );
115
116            let status = response.status_code;
117            let sdk_region_str = response
118                .headers
119                .as_ref()
120                .and_then(|h| h.get("x-statsig-region"));
121            let success = (200..300).contains(&status);
122
123            let error_message = response
124                .error
125                .unwrap_or_else(|| get_error_message_for_status(status));
126
127            if let Some(key) = request_args.diagnostics_key {
128                let mut end_marker =
129                    Marker::new(key, ActionType::End, Some(StepType::NetworkRequest))
130                        .with_attempt(attempt)
131                        .with_url(request_args.url.clone())
132                        .with_status_code(status)
133                        .with_is_success(success)
134                        .with_sdk_region(sdk_region_str.map(|s| s.to_owned()));
135
136                let error_map = if !error_message.is_empty() {
137                    let mut map = HashMap::new();
138                    map.insert("name".to_string(), "NetworkError".to_string());
139                    map.insert("message".to_string(), error_message.clone());
140                    map.insert("code".to_string(), status.to_string());
141                    Some(map)
142                } else {
143                    None
144                };
145
146                if let Some(error_map) = error_map {
147                    end_marker = end_marker.with_error(error_map);
148                }
149
150                self.ops_stats.add_marker(end_marker, None);
151            }
152
153            if success {
154                return get_data_as_string(response.data);
155            }
156
157            if !RETRY_CODES.contains(&status) {
158                log_error_to_statsig_and_console!(
159                    &self.ops_stats,
160                    TAG,
161                    "status:{} message:{}",
162                    status,
163                    error_message
164                );
165                return Err(NetworkError::RequestFailed);
166            }
167
168            if attempt >= request_args.retries {
169                log_error_to_statsig_and_console!(
170                    &self.ops_stats,
171                    TAG,
172                    "Network error, retries exhausted: {} {}",
173                    status,
174                    error_message
175                );
176                return Err(NetworkError::RetriesExhausted);
177            }
178
179            attempt += 1;
180            let backoff_ms = 2_u64.pow(attempt) * 100;
181
182            log_w!(
183                TAG, "Network request failed with status code {} (attempt {}), will retry after {}ms...\n{}",
184                status,
185                attempt,
186                backoff_ms,
187                error_message
188            );
189
190            tokio::time::sleep(Duration::from_millis(backoff_ms)).await;
191        }
192    }
193}
194
/// Returns a human-readable message for an HTTP status code.
///
/// Any 2xx status yields an empty string, meaning "no error". A few common
/// statuses map to their reason phrase, `0` maps to `"Unknown Error"`, and
/// every other status becomes `"HTTP Error <status>"`.
fn get_error_message_for_status(status: u16) -> String {
    let phrase = match status {
        200..=299 => "",
        0 => "Unknown Error",
        400 => "Bad Request",
        401 => "Unauthorized",
        403 => "Forbidden",
        404 => "Not Found",
        405 => "Method Not Allowed",
        406 => "Not Acceptable",
        408 => "Request Timeout",
        500 => "Internal Server Error",
        502 => "Bad Gateway",
        503 => "Service Unavailable",
        504 => "Gateway Timeout",
        other => return format!("HTTP Error {other}"),
    };
    phrase.to_owned()
}
216
217fn get_data_as_string(data: Option<Vec<u8>>) -> Result<String, NetworkError> {
218    // todo: support compressed data
219    match data {
220        Some(data) => Ok(String::from_utf8(data).map_err(|_| NetworkError::SerializationError)?),
221        None => Err(NetworkError::RequestFailed),
222    }
223}