// statsig_rust/networking/network_client.rs

1use chrono::Utc;
2
3use super::network_error::NetworkError;
4use super::providers::get_network_provider;
5use super::{HttpMethod, NetworkProvider, RequestArgs, Response};
6use crate::networking::proxy_config::ProxyConfig;
7use crate::observability::observability_client_adapter::{MetricType, ObservabilityEvent};
8use crate::observability::ops_stats::{OpsStatsForInstance, OPS_STATS};
9use crate::observability::ErrorBoundaryEvent;
10use crate::sdk_diagnostics::marker::{ActionType, Marker, StepType};
11use crate::utils::{is_version_segment, split_host_and_path, strip_query_and_fragment};
12use crate::{log_d, log_i, log_w, StatsigOptions};
13use std::collections::HashMap;
14use std::sync::atomic::{AtomicBool, Ordering};
15use std::sync::{Arc, Weak};
16use std::time::{Duration, Instant};
17
/// HTTP statuses that fail a request immediately instead of being retried.
const NON_RETRY_CODES: [u16; 6] = [
    400, // Bad Request
    403, // Forbidden
    405, // Method Not Allowed
    413, // Payload Too Large
    429, // Too Many Requests
    501, // Not Implemented
];

/// Logged when a request is abandoned because the shutdown flag is set.
const SHUTDOWN_ERROR: &str = "Request was aborted because the client is shutting down";

/// Maximum number of characters of an unrecognized URL path reported as `request_path`.
const MAX_REQUEST_PATH_LENGTH: usize = 64;
/// Number of leading SDK-key characters attached to latency metrics.
const LOGGABLE_KEY_PREFIX_LENGTH: usize = 13;

// Endpoints whose request latency is reported to observability.
const DOWNLOAD_CONFIG_SPECS_ENDPOINT: &str = "download_config_specs";
const GET_ID_LISTS_ENDPOINT: &str = "get_id_lists";
const DOWNLOAD_ID_LIST_FILE_ENDPOINT: &str = "download_id_list_file";

// Latency metric name and the tag keys attached to it.
const NETWORK_REQUEST_LATENCY_METRIC: &str = "network_request.latency";
const REQUEST_PATH_TAG: &str = "request_path";
const STATUS_CODE_TAG: &str = "status_code";
const IS_SUCCESS_TAG: &str = "is_success";
const SDK_KEY_TAG: &str = "sdk_key";
const SOURCE_SERVICE_TAG: &str = "source_service";
const ID_LIST_FILE_ID_TAG: &str = "id_list_file_id";

/// Log tag used by every log line emitted from this module.
const TAG: &str = stringify!(NetworkClient);

43pub struct NetworkClient {
44    headers: HashMap<String, String>,
45    is_shutdown: Arc<AtomicBool>,
46    ops_stats: Arc<OpsStatsForInstance>,
47    net_provider: Weak<dyn NetworkProvider>,
48    disable_network: bool,
49    proxy_config: Option<ProxyConfig>,
50    silent_on_network_failure: bool,
51    disable_file_streaming: bool,
52    loggable_sdk_key: String,
53}
54
55impl NetworkClient {
56    #[must_use]
57    pub fn new(
58        sdk_key: &str,
59        headers: Option<HashMap<String, String>>,
60        options: Option<&StatsigOptions>,
61    ) -> Self {
62        let net_provider = get_network_provider();
63        let (disable_network, proxy_config) = options
64            .map(|opts| {
65                (
66                    opts.disable_network.unwrap_or(false),
67                    opts.proxy_config.clone(),
68                )
69            })
70            .unwrap_or((false, None));
71
72        NetworkClient {
73            headers: headers.unwrap_or_default(),
74            is_shutdown: Arc::new(AtomicBool::new(false)),
75            net_provider,
76            ops_stats: OPS_STATS.get_for_instance(sdk_key),
77            disable_network,
78            proxy_config,
79            silent_on_network_failure: false,
80            disable_file_streaming: options
81                .map(|opts| opts.disable_disk_access.unwrap_or(false))
82                .unwrap_or(false),
83            loggable_sdk_key: get_loggable_sdk_key(sdk_key),
84        }
85    }
86
87    pub fn shutdown(&self) {
88        self.is_shutdown.store(true, Ordering::SeqCst);
89    }
90
91    pub async fn get(&self, request_args: RequestArgs) -> Result<Response, NetworkError> {
92        self.make_request(HttpMethod::GET, request_args).await
93    }
94
95    pub async fn post(
96        &self,
97        mut request_args: RequestArgs,
98        body: Option<Vec<u8>>,
99    ) -> Result<Response, NetworkError> {
100        request_args.body = body;
101        self.make_request(HttpMethod::POST, request_args).await
102    }
103
104    async fn make_request(
105        &self,
106        method: HttpMethod,
107        mut request_args: RequestArgs,
108    ) -> Result<Response, NetworkError> {
109        let is_shutdown = if let Some(is_shutdown) = &request_args.is_shutdown {
110            is_shutdown.clone()
111        } else {
112            self.is_shutdown.clone()
113        };
114
115        if self.disable_network {
116            log_d!(TAG, "Network is disabled, not making requests");
117            return Err(NetworkError::DisableNetworkOn(request_args.url));
118        }
119
120        request_args.populate_headers(self.headers.clone());
121
122        if request_args.disable_file_streaming.is_none() {
123            request_args.disable_file_streaming = Some(self.disable_file_streaming);
124        }
125
126        let mut merged_headers = request_args.headers.unwrap_or_default();
127        if !self.headers.is_empty() {
128            merged_headers.extend(self.headers.clone());
129        }
130        merged_headers.insert(
131            "STATSIG-CLIENT-TIME".into(),
132            Utc::now().timestamp_millis().to_string(),
133        );
134        request_args.headers = Some(merged_headers);
135
136        // passing down proxy config through request args
137        if let Some(proxy_config) = &self.proxy_config {
138            request_args.proxy_config = Some(proxy_config.clone());
139        }
140        let mut attempt = 0;
141
142        loop {
143            if let Some(key) = request_args.diagnostics_key {
144                self.ops_stats.add_marker(
145                    Marker::new(key, ActionType::Start, Some(StepType::NetworkRequest))
146                        .with_attempt(attempt)
147                        .with_url(request_args.url.clone()),
148                    None,
149                );
150            }
151            if is_shutdown.load(Ordering::SeqCst) {
152                log_i!(TAG, "{}", SHUTDOWN_ERROR);
153                return Err(NetworkError::ShutdownError(request_args.url));
154            }
155
156            let request_start = Instant::now();
157            let mut response = match self.net_provider.upgrade() {
158                Some(net_provider) => net_provider.send(&method, &request_args).await,
159                None => {
160                    return Err(NetworkError::RequestFailed(
161                        request_args.url,
162                        None,
163                        "Failed to get a NetworkProvider instance".to_string(),
164                    ));
165                }
166            };
167
168            let status = response.status_code;
169            let error_message = response
170                .error
171                .clone()
172                .unwrap_or_else(|| get_error_message_for_status(status, response.data.as_mut()));
173
174            let content_type = response
175                .data
176                .as_ref()
177                .and_then(|data| data.get_header_ref("content-type"));
178
179            log_d!(
180                TAG,
181                "Response url({}) status({:?}) content-type({:?})",
182                &request_args.url,
183                response.status_code,
184                content_type
185            );
186
187            let sdk_region_str = response
188                .data
189                .as_ref()
190                .and_then(|data| data.get_header_ref("x-statsig-region").cloned());
191            let success = (200..300).contains(&status.unwrap_or(0));
192            let duration_ms = request_start.elapsed().as_millis() as f64;
193            self.log_network_request_latency_to_ob(&request_args, status, success, duration_ms);
194
195            if let Some(key) = request_args.diagnostics_key {
196                let mut end_marker =
197                    Marker::new(key, ActionType::End, Some(StepType::NetworkRequest))
198                        .with_attempt(attempt)
199                        .with_url(request_args.url.clone())
200                        .with_is_success(success)
201                        .with_content_type(content_type.cloned())
202                        .with_sdk_region(sdk_region_str.map(|s| s.to_owned()));
203
204                if let Some(status_code) = status {
205                    end_marker = end_marker.with_status_code(status_code);
206                }
207
208                let error_map = if !error_message.is_empty() {
209                    let mut map = HashMap::new();
210                    map.insert("name".to_string(), "NetworkError".to_string());
211                    map.insert("message".to_string(), error_message.clone());
212                    let status_string = match status {
213                        Some(code) => code.to_string(),
214                        None => "None".to_string(),
215                    };
216                    map.insert("code".to_string(), status_string);
217                    Some(map)
218                } else {
219                    None
220                };
221
222                if let Some(error_map) = error_map {
223                    end_marker = end_marker.with_error(error_map);
224                }
225
226                self.ops_stats.add_marker(end_marker, None);
227            }
228
229            if success {
230                return Ok(response);
231            }
232
233            if NON_RETRY_CODES.contains(&status.unwrap_or(0)) {
234                let error = NetworkError::RequestNotRetryable(
235                    request_args.url.clone(),
236                    status,
237                    error_message,
238                );
239                self.log_warning(&error, &request_args);
240                return Err(error);
241            }
242
243            if attempt >= request_args.retries {
244                let error = NetworkError::RetriesExhausted(
245                    request_args.url.clone(),
246                    status,
247                    attempt + 1,
248                    error_message,
249                );
250                self.log_warning(&error, &request_args);
251                return Err(error);
252            }
253
254            attempt += 1;
255            let backoff_ms = 2_u64.pow(attempt) * 100;
256
257            log_i!(
258                TAG, "Network request failed with status code {} (attempt {}/{}), will retry after {}ms...\n{}",
259                status.map_or("unknown".to_string(), |s| s.to_string()),
260                attempt,
261                request_args.retries + 1,
262                backoff_ms,
263                error_message
264            );
265
266            tokio::time::sleep(Duration::from_millis(backoff_ms)).await;
267        }
268    }
269
270    pub fn mute_network_error_log(mut self) -> Self {
271        self.silent_on_network_failure = true;
272        self
273    }
274
275    // Logging helpers
276    fn log_warning(&self, error: &NetworkError, args: &RequestArgs) {
277        let exception = error.name();
278
279        log_w!(TAG, "{}", error);
280        if !self.silent_on_network_failure {
281            let dedupe_key = format!("{:?}", args.diagnostics_key);
282            self.ops_stats.log_error(ErrorBoundaryEvent {
283                tag: TAG.to_string(),
284                exception: exception.to_string(),
285                bypass_dedupe: false,
286                info: serde_json::to_string(error).unwrap_or_default(),
287                dedupe_key: Some(dedupe_key),
288                extra: None,
289            });
290        }
291    }
292
293    // ------------------------------------------------------------
294    // Observability Logging Helpers (OB only) - START
295    // ------------------------------------------------------------
296    fn log_network_request_latency_to_ob(
297        &self,
298        request_args: &RequestArgs,
299        status: Option<u16>,
300        success: bool,
301        duration_ms: f64,
302    ) {
303        let url = request_args.url.as_str();
304        if !should_log_network_request_latency(url) {
305            return;
306        }
307
308        let status_code = status
309            .map(|code| code.to_string())
310            .unwrap_or("none".to_string());
311        let tags = get_network_request_latency_tags(
312            request_args,
313            status_code,
314            success,
315            self.loggable_sdk_key.clone(),
316        );
317
318        self.ops_stats.log(ObservabilityEvent::new_event(
319            MetricType::Dist,
320            NETWORK_REQUEST_LATENCY_METRIC.to_string(),
321            duration_ms,
322            Some(tags),
323        ));
324    }
325}
326
327fn get_network_request_latency_tags(
328    request_args: &RequestArgs,
329    status_code: String,
330    success: bool,
331    loggable_sdk_key: String,
332) -> HashMap<String, String> {
333    let (source_service, request_path) = get_source_service_and_request_path(&request_args.url);
334    let mut tags = HashMap::from([
335        (REQUEST_PATH_TAG.to_string(), request_path),
336        (STATUS_CODE_TAG.to_string(), status_code),
337        (IS_SUCCESS_TAG.to_string(), success.to_string()),
338        (SDK_KEY_TAG.to_string(), loggable_sdk_key),
339        (SOURCE_SERVICE_TAG.to_string(), source_service),
340    ]);
341    if let Some(id_list_file_id) = request_args
342        .id_list_file_id
343        .as_ref()
344        .filter(|id| !id.is_empty())
345    {
346        tags.insert(ID_LIST_FILE_ID_TAG.to_string(), id_list_file_id.clone());
347    }
348
349    tags
350}
351
352fn get_loggable_sdk_key(sdk_key: &str) -> String {
353    sdk_key.chars().take(LOGGABLE_KEY_PREFIX_LENGTH).collect()
354}
355
356fn is_latency_loggable_endpoint(endpoint: &str) -> bool {
357    endpoint == DOWNLOAD_CONFIG_SPECS_ENDPOINT
358        || endpoint == GET_ID_LISTS_ENDPOINT
359        || endpoint == DOWNLOAD_ID_LIST_FILE_ENDPOINT
360}
361
362fn get_version_and_endpoint_for_latency<'a>(
363    segments: &'a [&'a str],
364) -> Option<(usize, &'a str, &'a str)> {
365    // Find a known endpoint pattern, then verify the segment right before it is `/v{number}`.
366    segments
367        .iter()
368        .enumerate()
369        .find_map(|(endpoint_index, endpoint_segment)| {
370            if !is_latency_loggable_endpoint(endpoint_segment) || endpoint_index == 0 {
371                return None;
372            }
373
374            let version_index = endpoint_index - 1;
375            let version_segment = segments[version_index];
376            is_version_segment(version_segment).then_some((
377                version_index,
378                version_segment,
379                *endpoint_segment,
380            ))
381        })
382}
383
384fn should_log_network_request_latency(url: &str) -> bool {
385    let (_, raw_path) = split_host_and_path(url);
386    let normalized_path = strip_query_and_fragment(raw_path).trim_start_matches('/');
387    let segments: Vec<&str> = normalized_path
388        .split('/')
389        .filter(|segment| !segment.is_empty())
390        .collect();
391
392    get_version_and_endpoint_for_latency(&segments).is_some()
393}
394
/// Concatenates `host_prefix` and `path`. An empty prefix yields `path` unchanged.
fn with_host_prefix(host_prefix: &str, path: &str) -> String {
    let mut joined = String::with_capacity(host_prefix.len() + path.len());
    joined.push_str(host_prefix);
    joined.push_str(path);
    joined
}

403fn get_source_service_and_request_path(url: &str) -> (String, String) {
404    let (host_prefix, raw_path) = split_host_and_path(url);
405    let normalized_path = strip_query_and_fragment(raw_path).trim_start_matches('/');
406    let segments: Vec<&str> = normalized_path
407        .split('/')
408        .filter(|segment| !segment.is_empty())
409        .collect();
410
411    if let Some((version_index, version_segment, endpoint_segment)) =
412        get_version_and_endpoint_for_latency(&segments)
413    {
414        let request_path = format!("/{version_segment}/{endpoint_segment}");
415        let source_service_suffix = segments[..version_index].join("/");
416        let source_service = with_host_prefix(&host_prefix, &source_service_suffix)
417            .trim_end_matches('/')
418            .to_string();
419        return (source_service, request_path);
420    }
421
422    let fallback_request_path: String = normalized_path
423        .chars()
424        .take(MAX_REQUEST_PATH_LENGTH)
425        .collect();
426    let request_path = if fallback_request_path.is_empty() {
427        "/".to_string()
428    } else {
429        format!("/{fallback_request_path}")
430    };
431    let source_service = host_prefix.trim_end_matches('/').to_string();
432    (source_service, request_path)
433}
434
/// Test helper: returns only the `request_path` half of
/// `get_source_service_and_request_path`.
#[cfg(test)]
fn get_request_path(url: &str) -> String {
    let (_source_service, request_path) = get_source_service_and_request_path(url);
    request_path
}

// ------------------------------------------------------------
// Observability Logging Helpers (OB only) - END
// ------------------------------------------------------------

444fn get_error_message_for_status(
445    status: Option<u16>,
446    data: Option<&mut super::ResponseData>,
447) -> String {
448    if (200..300).contains(&status.unwrap_or(0)) {
449        return String::new();
450    }
451
452    let mut message = String::new();
453    if let Some(data) = data {
454        let lossy_str = data.read_to_string().unwrap_or_default();
455        if lossy_str.is_ascii() {
456            message = lossy_str.to_string();
457        }
458    }
459
460    let status_value = match status {
461        Some(code) => code,
462        None => return format!("HTTP Error None: {message}"),
463    };
464
465    let generic_message = match status_value {
466        400 => "Bad Request",
467        401 => "Unauthorized",
468        403 => "Forbidden",
469        404 => "Not Found",
470        405 => "Method Not Allowed",
471        406 => "Not Acceptable",
472        408 => "Request Timeout",
473        500 => "Internal Server Error",
474        502 => "Bad Gateway",
475        503 => "Service Unavailable",
476        504 => "Gateway Timeout",
477        0 => "Unknown Error",
478        _ => return format!("HTTP Error {status_value}: {message}"),
479    };
480
481    if message.is_empty() {
482        return generic_message.to_string();
483    }
484
485    format!("{generic_message}: {message}")
486}
487
#[cfg(test)]
mod tests {
    use super::{
        get_loggable_sdk_key, get_network_request_latency_tags, get_request_path,
        get_source_service_and_request_path, should_log_network_request_latency,
        with_host_prefix, ID_LIST_FILE_ID_TAG, MAX_REQUEST_PATH_LENGTH, REQUEST_PATH_TAG,
    };
    use crate::networking::RequestArgs;

    #[test]
    fn test_get_request_path_with_sample_urls() {
        assert_eq!(
            get_request_path("https://api.statsigcdn.com/v1/download_id_list_file/3wHgh0FhoQH0p"),
            "/v1/download_id_list_file"
        );
        assert_eq!(
            get_request_path("https://api.statsigcdn.com/v1/download_id_list_file/Q9mXcz7L1P43tRb8kV2dHyw%2FM6nJf0Ae5uTqsrC4Gp9KZ?foo=bar"),
            "/v1/download_id_list_file"
        );
        assert_eq!(
            get_request_path("https://api.statsig.com/v1/get_id_lists/secret-abcdef"),
            "/v1/get_id_lists"
        );
        assert_eq!(
            get_request_path("https://api.statsigcdn.com/v2/download_config_specs/secret-123456"),
            "/v2/download_config_specs"
        );
    }

    #[test]
    fn test_get_request_path_falls_back_to_capped_raw_path() {
        // Unknown endpoints report the raw path...
        assert_eq!(
            get_request_path("https://api.statsig.com/v1/log_event"),
            "/v1/log_event"
        );

        // ...capped at MAX_REQUEST_PATH_LENGTH characters.
        let long_segment = "a".repeat(MAX_REQUEST_PATH_LENGTH * 2);
        let request_path = get_request_path(&format!("https://api.statsig.com/{long_segment}"));
        assert_eq!(
            request_path,
            format!("/{}", "a".repeat(MAX_REQUEST_PATH_LENGTH))
        );
    }

    #[test]
    fn test_should_log_network_request_latency_for_supported_endpoints() {
        assert!(!should_log_network_request_latency(
            "https://api.statsig.com/v1/log_event"
        ));
        assert!(!should_log_network_request_latency(
            "https://api.statsig.com/v1/sdk_exception"
        ));
        assert!(should_log_network_request_latency(
            "https://api.statsig.com/v1/get_id_lists/secret-abcdef"
        ));
        assert!(should_log_network_request_latency(
            "https://api.statsigcdn.com/v2/download_config_specs/secret-123456"
        ));
        assert!(should_log_network_request_latency(
            "https://api.statsigcdn.com/v1/download_id_list_file/3wHgh0FhoQH0p"
        ));
    }

    #[test]
    fn test_should_log_requires_version_directly_before_endpoint() {
        assert!(!should_log_network_request_latency(
            "https://api.statsig.com/download_config_specs/secret-123456"
        ));
        assert!(!should_log_network_request_latency(
            "https://api.statsig.com/v1/extra/download_config_specs/secret-123456"
        ));
    }

    #[test]
    fn test_get_source_service_and_request_path() {
        let (source_service, request_path) = get_source_service_and_request_path(
            "http://127.0.0.1:12345/mock-uuid/v2/download_config_specs/secret-key.json?x=1",
        );
        assert_eq!(source_service, "http://127.0.0.1:12345/mock-uuid");
        assert_eq!(request_path, "/v2/download_config_specs");
    }

    #[test]
    fn test_with_host_prefix() {
        assert_eq!(with_host_prefix("", "mock-uuid"), "mock-uuid");
        assert_eq!(
            with_host_prefix("http://127.0.0.1:12345/", "mock-uuid"),
            "http://127.0.0.1:12345/mock-uuid"
        );
    }

    #[test]
    fn test_get_loggable_sdk_key_keeps_prefix_only() {
        assert_eq!(
            get_loggable_sdk_key("secret-abcdefghijklmnop"),
            "secret-abcdef"
        );
        assert_eq!(get_loggable_sdk_key("short"), "short");
        assert_eq!(get_loggable_sdk_key(""), "");
    }

    #[test]
    fn test_network_latency_tags_include_id_list_file_id_only_when_present() {
        let mut request_args = RequestArgs {
            url: "https://api.statsigcdn.com/v1/download_id_list_file/file-123".to_string(),
            id_list_file_id: Some("file-123".to_string()),
            ..RequestArgs::new()
        };

        let tags = get_network_request_latency_tags(
            &request_args,
            "200".to_string(),
            true,
            "client-key-123".to_string(),
        );
        assert_eq!(tags.get(ID_LIST_FILE_ID_TAG), Some(&"file-123".to_string()));
        assert_eq!(
            tags.get(REQUEST_PATH_TAG),
            Some(&"/v1/download_id_list_file".to_string())
        );

        request_args.id_list_file_id = Some(String::new());
        let tags_without_id = get_network_request_latency_tags(
            &request_args,
            "200".to_string(),
            true,
            "client-key-123".to_string(),
        );
        assert!(!tags_without_id.contains_key(ID_LIST_FILE_ID_TAG));
    }
}