Skip to main content

statsig_rust/networking/
network_client.rs

1use chrono::Utc;
2
3use super::network_error::NetworkError;
4use super::providers::get_network_provider;
5use super::{HttpMethod, NetworkProvider, RequestArgs, Response};
6use crate::networking::proxy_config::ProxyConfig;
7use crate::observability::observability_client_adapter::{MetricType, ObservabilityEvent};
8use crate::observability::ops_stats::{OpsStatsForInstance, OPS_STATS};
9use crate::observability::ErrorBoundaryEvent;
10use crate::sdk_diagnostics::marker::{ActionType, Marker, StepType};
11use crate::utils::{
12    get_loggable_sdk_key, is_version_segment, split_host_and_path, strip_query_and_fragment,
13};
14use crate::{log_d, log_i, log_w, StatsigOptions};
15use std::collections::HashMap;
16use std::sync::atomic::{AtomicBool, Ordering};
17use std::sync::{Arc, Weak};
18use std::time::{Duration, Instant};
19
/// HTTP status codes that fail a request immediately instead of being retried.
const NON_RETRY_CODES: [u16; 6] = [
    400, // Bad Request
    403, // Forbidden
    405, // Method Not Allowed
    413, // Payload Too Large
    429, // Too Many Requests
    501, // Not Implemented
];
/// Message logged when a request is abandoned because the shutdown flag is set.
const SHUTDOWN_ERROR: &str = "Request was aborted because the client is shutting down";

/// Longest (in chars) URL path kept as a `request_path` tag when no known endpoint matches.
const MAX_REQUEST_PATH_LENGTH: usize = 64;

// Endpoints whose request latency is reported to observability.
const DOWNLOAD_CONFIG_SPECS_ENDPOINT: &str = "download_config_specs";
const GET_ID_LISTS_ENDPOINT: &str = "get_id_lists";
const DOWNLOAD_ID_LIST_FILE_ENDPOINT: &str = "download_id_list_file";

// Metric name and tag keys for the network latency distribution.
const NETWORK_REQUEST_LATENCY_METRIC: &str = "network_request.latency";
const REQUEST_PATH_TAG: &str = "request_path";
const STATUS_CODE_TAG: &str = "status_code";
const IS_SUCCESS_TAG: &str = "is_success";
const SDK_KEY_TAG: &str = "sdk_key";
const SOURCE_SERVICE_TAG: &str = "source_service";
const ID_LIST_FILE_ID_TAG: &str = "id_list_file_id";
const DELTAS_USED_TAG: &str = "deltas_used";

/// Log tag used by every message emitted from this module.
const TAG: &str = stringify!(NetworkClient);
44
/// HTTP client used by the SDK for its network requests.
///
/// It applies client-wide headers, proxy/CA settings and a shutdown flag to each request,
/// retries failed requests with exponential backoff, and records diagnostics markers,
/// error-boundary events and latency metrics through `ops_stats`.
pub struct NetworkClient {
    /// Headers applied to every request sent by this client.
    headers: HashMap<String, String>,
    /// Client-wide shutdown flag set by `shutdown()`; checked before each attempt unless the
    /// request carries its own `is_shutdown` flag.
    is_shutdown: Arc<AtomicBool>,
    /// Per-SDK-key stats sink for diagnostics markers, error events and observability metrics.
    ops_stats: Arc<OpsStatsForInstance>,
    /// Non-owning handle to the network provider; if it can no longer be upgraded, requests
    /// fail with `NetworkError::RequestFailed`.
    net_provider: Weak<dyn NetworkProvider>,
    /// When true (from `StatsigOptions::disable_network`), every request fails with
    /// `NetworkError::DisableNetworkOn` without touching the network.
    disable_network: bool,
    /// Proxy settings from `StatsigOptions::proxy_config`, copied onto every request.
    proxy_config: Option<ProxyConfig>,
    /// Contents of `proxy_config.ca_cert_path`, read once at construction; used for requests
    /// that do not specify their own CA certificate.
    ca_cert_pem: Option<Vec<u8>>,
    /// Set by `mute_network_error_log()`; when true, terminal failures are only logged locally
    /// and not reported through `ops_stats.log_error`.
    silent_on_network_failure: bool,
    /// Derived from `StatsigOptions::disable_disk_access`; default for requests that leave
    /// `disable_file_streaming` unset.
    disable_file_streaming: bool,
    /// SDK key as returned by `get_loggable_sdk_key` (presumably redacted — confirm), used as
    /// the `sdk_key` metric tag.
    loggable_sdk_key: String,
}
57
58impl NetworkClient {
59    #[must_use]
60    pub fn new(
61        sdk_key: &str,
62        headers: Option<HashMap<String, String>>,
63        options: Option<&StatsigOptions>,
64    ) -> Self {
65        let net_provider = get_network_provider();
66        let (disable_network, proxy_config, ca_cert_pem) = options
67            .map(|opts| {
68                let ca_cert_pem = opts
69                    .proxy_config
70                    .as_ref()
71                    .and_then(|cfg| cfg.ca_cert_path.as_ref())
72                    .and_then(|path| {
73                        if path.is_empty() {
74                            return None;
75                        }
76                        match std::fs::read(path) {
77                            Ok(bytes) => Some(bytes),
78                            Err(e) => {
79                                log_w!(
80                                    TAG,
81                                    "Failed to read proxy_config.ca_cert_path '{}': {}",
82                                    path,
83                                    e
84                                );
85                                None
86                            }
87                        }
88                    });
89                (
90                    opts.disable_network.unwrap_or(false),
91                    opts.proxy_config.clone(),
92                    ca_cert_pem,
93                )
94            })
95            .unwrap_or((false, None, None));
96
97        NetworkClient {
98            headers: headers.unwrap_or_default(),
99            is_shutdown: Arc::new(AtomicBool::new(false)),
100            net_provider,
101            ops_stats: OPS_STATS.get_for_instance(sdk_key),
102            disable_network,
103            proxy_config,
104            ca_cert_pem,
105            silent_on_network_failure: false,
106            disable_file_streaming: options
107                .map(|opts| opts.disable_disk_access.unwrap_or(false))
108                .unwrap_or(false),
109            loggable_sdk_key: get_loggable_sdk_key(sdk_key),
110        }
111    }
112
113    pub fn shutdown(&self) {
114        self.is_shutdown.store(true, Ordering::SeqCst);
115    }
116
117    pub async fn get(&self, request_args: RequestArgs) -> Result<Response, NetworkError> {
118        self.make_request(HttpMethod::GET, request_args).await
119    }
120
121    pub async fn post(
122        &self,
123        mut request_args: RequestArgs,
124        body: Option<Vec<u8>>,
125    ) -> Result<Response, NetworkError> {
126        request_args.body = body;
127        self.make_request(HttpMethod::POST, request_args).await
128    }
129
130    async fn make_request(
131        &self,
132        method: HttpMethod,
133        mut request_args: RequestArgs,
134    ) -> Result<Response, NetworkError> {
135        let is_shutdown = if let Some(is_shutdown) = &request_args.is_shutdown {
136            is_shutdown.clone()
137        } else {
138            self.is_shutdown.clone()
139        };
140
141        if self.disable_network {
142            log_d!(TAG, "Network is disabled, not making requests");
143            return Err(NetworkError::DisableNetworkOn(request_args.url));
144        }
145
146        request_args.populate_headers(self.headers.clone());
147
148        if request_args.disable_file_streaming.is_none() {
149            request_args.disable_file_streaming = Some(self.disable_file_streaming);
150        }
151
152        if request_args.ca_cert_pem.is_none() {
153            request_args.ca_cert_pem = self.ca_cert_pem.clone();
154        }
155
156        let mut merged_headers = request_args.headers.unwrap_or_default();
157        if !self.headers.is_empty() {
158            merged_headers.extend(self.headers.clone());
159        }
160        merged_headers.insert(
161            "STATSIG-CLIENT-TIME".into(),
162            Utc::now().timestamp_millis().to_string(),
163        );
164        request_args.headers = Some(merged_headers);
165
166        // passing down proxy config through request args
167        if let Some(proxy_config) = &self.proxy_config {
168            request_args.proxy_config = Some(proxy_config.clone());
169        }
170        let mut attempt = 0;
171
172        loop {
173            if let Some(key) = request_args.diagnostics_key {
174                self.ops_stats.add_marker(
175                    Marker::new(key, ActionType::Start, Some(StepType::NetworkRequest))
176                        .with_attempt(attempt)
177                        .with_url(request_args.url.clone()),
178                    None,
179                );
180            }
181            if is_shutdown.load(Ordering::SeqCst) {
182                log_i!(TAG, "{}", SHUTDOWN_ERROR);
183                return Err(NetworkError::ShutdownError(request_args.url));
184            }
185
186            let request_start = Instant::now();
187            let mut response = match self.net_provider.upgrade() {
188                Some(net_provider) => net_provider.send(&method, &request_args).await,
189                None => {
190                    return Err(NetworkError::RequestFailed(
191                        request_args.url,
192                        None,
193                        "Failed to get a NetworkProvider instance".to_string(),
194                    ));
195                }
196            };
197
198            let status = response.status_code;
199            let error_message = response
200                .error
201                .clone()
202                .unwrap_or_else(|| get_error_message_for_status(status, response.data.as_mut()));
203
204            let content_type = response
205                .data
206                .as_ref()
207                .and_then(|data| data.get_header_ref("content-type"));
208
209            log_d!(
210                TAG,
211                "Response url({}) status({:?}) content-type({:?})",
212                &request_args.url,
213                response.status_code,
214                content_type
215            );
216
217            let sdk_region_str = response
218                .data
219                .as_ref()
220                .and_then(|data| data.get_header_ref("x-statsig-region").cloned());
221            let success = (200..300).contains(&status.unwrap_or(0));
222            let duration_ms = request_start.elapsed().as_millis() as f64;
223            self.log_network_request_latency_to_ob(&request_args, status, success, duration_ms);
224
225            if let Some(key) = request_args.diagnostics_key {
226                let mut end_marker =
227                    Marker::new(key, ActionType::End, Some(StepType::NetworkRequest))
228                        .with_attempt(attempt)
229                        .with_url(request_args.url.clone())
230                        .with_is_success(success)
231                        .with_content_type(content_type.cloned())
232                        .with_sdk_region(sdk_region_str.map(|s| s.to_owned()));
233
234                if let Some(status_code) = status {
235                    end_marker = end_marker.with_status_code(status_code);
236                }
237
238                let error_map = if !error_message.is_empty() {
239                    let mut map = HashMap::new();
240                    map.insert("name".to_string(), "NetworkError".to_string());
241                    map.insert("message".to_string(), error_message.clone());
242                    let status_string = match status {
243                        Some(code) => code.to_string(),
244                        None => "None".to_string(),
245                    };
246                    map.insert("code".to_string(), status_string);
247                    Some(map)
248                } else {
249                    None
250                };
251
252                if let Some(error_map) = error_map {
253                    end_marker = end_marker.with_error(error_map);
254                }
255
256                self.ops_stats.add_marker(end_marker, None);
257            }
258
259            if success {
260                return Ok(response);
261            }
262
263            if NON_RETRY_CODES.contains(&status.unwrap_or(0)) {
264                let error = NetworkError::RequestNotRetryable(
265                    request_args.url.clone(),
266                    status,
267                    error_message,
268                );
269                self.log_warning(&error, &request_args);
270                return Err(error);
271            }
272
273            if attempt >= request_args.retries {
274                let error = NetworkError::RetriesExhausted(
275                    request_args.url.clone(),
276                    status,
277                    attempt + 1,
278                    error_message,
279                );
280                self.log_warning(&error, &request_args);
281                return Err(error);
282            }
283
284            attempt += 1;
285            let backoff_ms = 2_u64.pow(attempt) * 100;
286
287            log_i!(
288                TAG, "Network request failed with status code {} (attempt {}/{}), will retry after {}ms...\n{}",
289                status.map_or("unknown".to_string(), |s| s.to_string()),
290                attempt,
291                request_args.retries + 1,
292                backoff_ms,
293                error_message
294            );
295
296            tokio::time::sleep(Duration::from_millis(backoff_ms)).await;
297        }
298    }
299
300    pub fn mute_network_error_log(mut self) -> Self {
301        self.silent_on_network_failure = true;
302        self
303    }
304
305    // Logging helpers
306    fn log_warning(&self, error: &NetworkError, args: &RequestArgs) {
307        let exception = error.name();
308
309        log_w!(TAG, "{}", error);
310        if !self.silent_on_network_failure {
311            let dedupe_key = format!("{:?}", args.diagnostics_key);
312            self.ops_stats.log_error(ErrorBoundaryEvent {
313                tag: TAG.to_string(),
314                exception: exception.to_string(),
315                bypass_dedupe: false,
316                info: serde_json::to_string(error).unwrap_or_default(),
317                dedupe_key: Some(dedupe_key),
318                extra: None,
319            });
320        }
321    }
322
323    // ------------------------------------------------------------
324    // Observability Logging Helpers (OB only) - START
325    // ------------------------------------------------------------
326    fn log_network_request_latency_to_ob(
327        &self,
328        request_args: &RequestArgs,
329        status: Option<u16>,
330        success: bool,
331        duration_ms: f64,
332    ) {
333        let url = request_args.url.as_str();
334        if !should_log_network_request_latency(url) {
335            return;
336        }
337
338        let status_code = status
339            .map(|code| code.to_string())
340            .unwrap_or("none".to_string());
341        let tags = get_network_request_latency_tags(
342            request_args,
343            status_code,
344            success,
345            self.loggable_sdk_key.clone(),
346        );
347
348        self.ops_stats.log(ObservabilityEvent::new_event(
349            MetricType::Dist,
350            NETWORK_REQUEST_LATENCY_METRIC.to_string(),
351            duration_ms,
352            Some(tags),
353        ));
354    }
355}
356
357fn get_network_request_latency_tags(
358    request_args: &RequestArgs,
359    status_code: String,
360    success: bool,
361    loggable_sdk_key: String,
362) -> HashMap<String, String> {
363    let (source_service, request_path) = get_source_service_and_request_path(&request_args.url);
364    let mut tags = HashMap::from([
365        (REQUEST_PATH_TAG.to_string(), request_path),
366        (STATUS_CODE_TAG.to_string(), status_code),
367        (IS_SUCCESS_TAG.to_string(), success.to_string()),
368        (SDK_KEY_TAG.to_string(), loggable_sdk_key),
369        (SOURCE_SERVICE_TAG.to_string(), source_service),
370        (
371            DELTAS_USED_TAG.to_string(),
372            request_args.deltas_enabled.to_string(),
373        ),
374    ]);
375    if let Some(id_list_file_id) = request_args
376        .id_list_file_id
377        .as_ref()
378        .filter(|id| !id.is_empty())
379    {
380        tags.insert(ID_LIST_FILE_ID_TAG.to_string(), id_list_file_id.clone());
381    }
382
383    tags
384}
385
386fn is_latency_loggable_endpoint(endpoint: &str) -> bool {
387    endpoint == DOWNLOAD_CONFIG_SPECS_ENDPOINT
388        || endpoint == GET_ID_LISTS_ENDPOINT
389        || endpoint == DOWNLOAD_ID_LIST_FILE_ENDPOINT
390}
391
392fn get_version_and_endpoint_for_latency<'a>(
393    segments: &'a [&'a str],
394) -> Option<(usize, &'a str, &'a str)> {
395    // Find a known endpoint pattern, then verify the segment right before it is `/v{number}`.
396    segments
397        .iter()
398        .enumerate()
399        .find_map(|(endpoint_index, endpoint_segment)| {
400            if !is_latency_loggable_endpoint(endpoint_segment) || endpoint_index == 0 {
401                return None;
402            }
403
404            let version_index = endpoint_index - 1;
405            let version_segment = segments[version_index];
406            is_version_segment(version_segment).then_some((
407                version_index,
408                version_segment,
409                *endpoint_segment,
410            ))
411        })
412}
413
414fn should_log_network_request_latency(url: &str) -> bool {
415    let (_, raw_path) = split_host_and_path(url);
416    let normalized_path = strip_query_and_fragment(raw_path).trim_start_matches('/');
417    let segments: Vec<&str> = normalized_path
418        .split('/')
419        .filter(|segment| !segment.is_empty())
420        .collect();
421
422    get_version_and_endpoint_for_latency(&segments).is_some()
423}
424
/// Concatenates `host_prefix` and `path`; an empty prefix yields `path` unchanged.
fn with_host_prefix(host_prefix: &str, path: &str) -> String {
    let mut joined = String::with_capacity(host_prefix.len() + path.len());
    joined.push_str(host_prefix);
    joined.push_str(path);
    joined
}
432
433fn get_source_service_and_request_path(url: &str) -> (String, String) {
434    let (host_prefix, raw_path) = split_host_and_path(url);
435    let normalized_path = strip_query_and_fragment(raw_path).trim_start_matches('/');
436    let segments: Vec<&str> = normalized_path
437        .split('/')
438        .filter(|segment| !segment.is_empty())
439        .collect();
440
441    if let Some((version_index, version_segment, endpoint_segment)) =
442        get_version_and_endpoint_for_latency(&segments)
443    {
444        let request_path = format!("/{version_segment}/{endpoint_segment}");
445        let source_service_suffix = segments[..version_index].join("/");
446        let source_service = with_host_prefix(&host_prefix, &source_service_suffix)
447            .trim_end_matches('/')
448            .to_string();
449        return (source_service, request_path);
450    }
451
452    let fallback_request_path: String = normalized_path
453        .chars()
454        .take(MAX_REQUEST_PATH_LENGTH)
455        .collect();
456    let request_path = if fallback_request_path.is_empty() {
457        "/".to_string()
458    } else {
459        format!("/{fallback_request_path}")
460    };
461    let source_service = host_prefix.trim_end_matches('/').to_string();
462    (source_service, request_path)
463}
464
#[cfg(test)]
/// Test helper: only the `request_path` half of `get_source_service_and_request_path`.
fn get_request_path(url: &str) -> String {
    let (_source_service, request_path) = get_source_service_and_request_path(url);
    request_path
}
469
470// ------------------------------------------------------------
471// Observability Logging Helpers (OB only) - END
472// ------------------------------------------------------------
473
474fn get_error_message_for_status(
475    status: Option<u16>,
476    data: Option<&mut super::ResponseData>,
477) -> String {
478    if (200..300).contains(&status.unwrap_or(0)) {
479        return String::new();
480    }
481
482    let mut message = String::new();
483    if let Some(data) = data {
484        let lossy_str = data.read_to_string().unwrap_or_default();
485        if lossy_str.is_ascii() {
486            message = lossy_str.to_string();
487        }
488    }
489
490    let status_value = match status {
491        Some(code) => code,
492        None => return format!("HTTP Error None: {message}"),
493    };
494
495    let generic_message = match status_value {
496        400 => "Bad Request",
497        401 => "Unauthorized",
498        403 => "Forbidden",
499        404 => "Not Found",
500        405 => "Method Not Allowed",
501        406 => "Not Acceptable",
502        408 => "Request Timeout",
503        500 => "Internal Server Error",
504        502 => "Bad Gateway",
505        503 => "Service Unavailable",
506        504 => "Gateway Timeout",
507        0 => "Unknown Error",
508        _ => return format!("HTTP Error {status_value}: {message}"),
509    };
510
511    if message.is_empty() {
512        return generic_message.to_string();
513    }
514
515    format!("{generic_message}: {message}")
516}
517
#[cfg(test)]
mod tests {
    use super::{
        get_network_request_latency_tags, get_request_path, get_source_service_and_request_path,
        should_log_network_request_latency, DELTAS_USED_TAG, ID_LIST_FILE_ID_TAG, REQUEST_PATH_TAG,
    };
    use crate::networking::RequestArgs;

    /// Builds latency tags for `request_args` with a fixed status, success flag and SDK key.
    fn tags_for(request_args: &RequestArgs) -> std::collections::HashMap<String, String> {
        get_network_request_latency_tags(
            request_args,
            "200".to_string(),
            true,
            "client-key-123".to_string(),
        )
    }

    #[test]
    fn test_get_request_path_with_sample_urls() {
        let cases = [
            (
                "https://api.statsigcdn.com/v1/download_id_list_file/3wHgh0FhoQH0p",
                "/v1/download_id_list_file",
            ),
            (
                "https://api.statsigcdn.com/v1/download_id_list_file/Q9mXcz7L1P43tRb8kV2dHyw%2FM6nJf0Ae5uTqsrC4Gp9KZ?foo=bar",
                "/v1/download_id_list_file",
            ),
            (
                "https://api.statsig.com/v1/get_id_lists/secret-abcdef",
                "/v1/get_id_lists",
            ),
            (
                "https://api.statsigcdn.com/v2/download_config_specs/secret-123456",
                "/v2/download_config_specs",
            ),
        ];

        for (url, expected) in cases {
            assert_eq!(get_request_path(url), expected, "request path for {url}");
        }
    }

    #[test]
    fn test_should_log_network_request_latency_for_supported_endpoints() {
        let unsupported = [
            "https://api.statsig.com/v1/log_event",
            "https://api.statsig.com/v1/sdk_exception",
        ];
        let supported = [
            "https://api.statsig.com/v1/get_id_lists/secret-abcdef",
            "https://api.statsigcdn.com/v2/download_config_specs/secret-123456",
            "https://api.statsigcdn.com/v1/download_id_list_file/3wHgh0FhoQH0p",
        ];

        for url in unsupported {
            assert!(
                !should_log_network_request_latency(url),
                "{url} should not be logged"
            );
        }
        for url in supported {
            assert!(
                should_log_network_request_latency(url),
                "{url} should be logged"
            );
        }
    }

    #[test]
    fn test_get_source_service_and_request_path() {
        let url = "http://127.0.0.1:12345/mock-uuid/v2/download_config_specs/secret-key.json?x=1";
        assert_eq!(
            get_source_service_and_request_path(url),
            (
                "http://127.0.0.1:12345/mock-uuid".to_string(),
                "/v2/download_config_specs".to_string(),
            )
        );
    }

    #[test]
    fn test_network_latency_tags_include_id_list_file_id_only_when_present() {
        let mut request_args = RequestArgs {
            url: "https://api.statsigcdn.com/v1/download_id_list_file/file-123".to_string(),
            id_list_file_id: Some("file-123".to_string()),
            ..RequestArgs::new()
        };

        let with_id = tags_for(&request_args);
        assert_eq!(
            with_id.get(ID_LIST_FILE_ID_TAG).map(String::as_str),
            Some("file-123")
        );
        assert_eq!(
            with_id.get(DELTAS_USED_TAG).map(String::as_str),
            Some("false")
        );
        assert_eq!(
            with_id.get(REQUEST_PATH_TAG).map(String::as_str),
            Some("/v1/download_id_list_file")
        );

        // An empty id is treated as absent; flipping deltas must be reflected in the tag.
        request_args.id_list_file_id = Some(String::new());
        request_args.deltas_enabled = true;

        let without_id = tags_for(&request_args);
        assert!(!without_id.contains_key(ID_LIST_FILE_ID_TAG));
        assert_eq!(
            without_id.get(DELTAS_USED_TAG).map(String::as_str),
            Some("true")
        );
    }
}