statsig_rust/networking/
network_client.rs

1use super::{HttpMethod, NetworkProvider, RequestArgs};
2use crate::networking::providers::Curl;
3use crate::observability::ops_stats::{OpsStatsForInstance, OPS_STATS};
4use crate::observability::ErrorBoundaryEvent;
5use crate::{log_error_to_statsig_and_console, log_i, log_w};
6use bytes::Bytes;
7use std::collections::HashMap;
8use std::sync::atomic::{AtomicBool, Ordering};
9use std::sync::Arc;
10use std::time::Duration;
11
/// HTTP status codes for which a failed request is eligible for another
/// attempt; every other non-2xx status fails the request immediately.
const RETRY_CODES: [u16; 8] = [
    408, 500, 502, 503, //
    504, 522, 524, 599,
];

/// Message logged when a request is abandoned because shutdown was signalled.
const SHUTDOWN_ERROR: &str = "Request was aborted because the client is shutting down";
14
/// Failure modes reported by [`NetworkClient`] requests.
///
/// Implements [`std::error::Error`] so it can be propagated with `?` into
/// `Box<dyn Error>`-style error types and rendered for humans via `Display`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum NetworkError {
    /// The request was abandoned because a shutdown flag was set.
    ShutdownError,
    /// The request failed with a non-retryable status, or a successful
    /// response carried no body.
    RequestFailed,
    /// The request kept failing with retryable statuses until the retry
    /// budget was used up.
    RetriesExhausted,
    /// Data could not be serialized.
    SerializationError,
}

impl std::fmt::Display for NetworkError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Exhaustive on purpose: adding a variant must force a message here.
        let message = match self {
            NetworkError::ShutdownError => "network request aborted: client is shutting down",
            NetworkError::RequestFailed => "network request failed",
            NetworkError::RetriesExhausted => "network request failed after exhausting all retries",
            NetworkError::SerializationError => "network payload serialization failed",
        };
        f.write_str(message)
    }
}

impl std::error::Error for NetworkError {}
/// Label attached to every log line emitted from this module.
const TAG: &str = "NetworkClient";
23
24pub struct NetworkClient {
25    headers: HashMap<String, String>,
26    is_shutdown: Arc<AtomicBool>,
27    net_provider: Arc<dyn NetworkProvider>,
28    ops_stats: Arc<OpsStatsForInstance>,
29}
30
31impl NetworkClient {
32    #[must_use]
33    pub fn new(sdk_key: &str, headers: Option<HashMap<String, String>>) -> Self {
34        NetworkClient {
35            headers: headers.unwrap_or_default(),
36            is_shutdown: Arc::new(AtomicBool::new(false)),
37            net_provider: Arc::new(Curl::get_instance(sdk_key)),
38            ops_stats: OPS_STATS.get_for_instance(sdk_key),
39        }
40    }
41
42    pub fn shutdown(&self) {
43        self.is_shutdown.store(true, Ordering::SeqCst);
44    }
45
46    pub async fn get(&self, request_args: RequestArgs) -> Result<String, NetworkError> {
47        self.make_request(HttpMethod::GET, request_args).await
48    }
49
50    pub async fn post(
51        &self,
52        mut request_args: RequestArgs,
53        body: Option<Bytes>,
54    ) -> Result<String, NetworkError> {
55        request_args.body = body;
56        self.make_request(HttpMethod::POST, request_args).await
57    }
58
59    async fn make_request(
60        &self,
61        method: HttpMethod,
62        mut request_args: RequestArgs,
63    ) -> Result<String, NetworkError> {
64        let is_shutdown = if let Some(is_shutdown) = &request_args.is_shutdown {
65            is_shutdown.clone()
66        } else {
67            self.is_shutdown.clone()
68        };
69
70        if !self.headers.is_empty() {
71            let mut merged_headers = request_args.headers.unwrap_or_default();
72            merged_headers.extend(self.headers.clone());
73            request_args.headers = Some(merged_headers);
74        }
75
76        let mut attempt = 0;
77
78        loop {
79            if is_shutdown.load(Ordering::SeqCst) {
80                log_i!(TAG, "{}", SHUTDOWN_ERROR);
81                return Err(NetworkError::ShutdownError);
82            }
83
84            let response = self.net_provider.send(&method, &request_args).await;
85
86            let status = response.status_code;
87
88            if (200..300).contains(&status) {
89                return response.data.ok_or(NetworkError::RequestFailed);
90            }
91
92            let error_message = response
93                .error
94                .unwrap_or_else(|| get_error_message_for_status(status));
95
96            if !RETRY_CODES.contains(&status) {
97                log_error_to_statsig_and_console!(
98                    &self.ops_stats,
99                    TAG,
100                    "status:{} message:{}",
101                    status,
102                    error_message
103                );
104                return Err(NetworkError::RequestFailed);
105            }
106
107            if attempt >= request_args.retries {
108                log_error_to_statsig_and_console!(
109                    &self.ops_stats,
110                    TAG,
111                    "Network error, retries exhausted: {} {}",
112                    status,
113                    error_message
114                );
115                return Err(NetworkError::RetriesExhausted);
116            }
117
118            attempt += 1;
119            let backoff_ms = 2_u64.pow(attempt) * 100;
120
121            log_w!(
122                TAG, "Network request failed with status code {} (attempt {}), will retry after {}ms...\n{}",
123                status,
124                attempt,
125                backoff_ms,
126                error_message
127            );
128
129            tokio::time::sleep(Duration::from_millis(backoff_ms)).await;
130        }
131    }
132}
133
/// Produces a fallback description for `status`, used when the transport did
/// not supply its own error text.
///
/// Well-known codes map to their standard reason phrase, `0` maps to
/// "Unknown Error", and anything else is rendered as `HTTP Error <status>`.
fn get_error_message_for_status(status: u16) -> String {
    let reason: &'static str = match status {
        0 => "Unknown Error",
        400 => "Bad Request",
        401 => "Unauthorized",
        403 => "Forbidden",
        404 => "Not Found",
        405 => "Method Not Allowed",
        406 => "Not Acceptable",
        408 => "Request Timeout",
        500 => "Internal Server Error",
        502 => "Bad Gateway",
        503 => "Service Unavailable",
        504 => "Gateway Timeout",
        // No fixed phrase for this code: build the generic message instead.
        _ => return format!("HTTP Error {status}"),
    };

    reason.to_owned()
}