Skip to main content

statsig_rust/networking/
network_client.rs

1use chrono::Utc;
2
3use super::network_error::NetworkError;
4use super::providers::get_network_provider;
5use super::{HttpMethod, NetworkProvider, RequestArgs, Response};
6use crate::networking::proxy_config::ProxyConfig;
7use crate::observability::observability_client_adapter::{MetricType, ObservabilityEvent};
8use crate::observability::ops_stats::{OpsStatsForInstance, OPS_STATS};
9use crate::observability::ErrorBoundaryEvent;
10use crate::sdk_diagnostics::marker::{ActionType, Marker, StepType};
11use crate::utils::{
12    get_loggable_sdk_key, is_version_segment, split_host_and_path, strip_query_and_fragment,
13};
14use crate::{log_d, log_i, log_w, StatsigOptions};
15use std::collections::HashMap;
16use std::sync::atomic::{AtomicBool, Ordering};
17use std::sync::{Arc, Weak};
18use std::time::{Duration, Instant};
19
/// HTTP status codes that fail the request immediately instead of being retried.
/// Kept in ascending order; only membership (`contains`) matters.
const NON_RETRY_CODES: [u16; 6] = [
    400, // Bad Request
    403, // Forbidden
    405, // Method Not Allowed
    413, // Payload Too Large
    429, // Too Many Requests
    501, // Not Implemented
];
/// Logged when an attempt is skipped because the shutdown flag is set.
const SHUTDOWN_ERROR: &str = "Request was aborted because the client is shutting down";

/// Maximum number of characters kept for a fallback `request_path` tag.
const MAX_REQUEST_PATH_LENGTH: usize = 64;

// Endpoints whose request latency is reported to observability.
const DOWNLOAD_CONFIG_SPECS_ENDPOINT: &str = "download_config_specs";
const GET_ID_LISTS_ENDPOINT: &str = "get_id_lists";
const DOWNLOAD_ID_LIST_FILE_ENDPOINT: &str = "download_id_list_file";

// Metric name and tag keys used for the latency distribution.
const NETWORK_REQUEST_LATENCY_METRIC: &str = "network_request.latency";
const REQUEST_PATH_TAG: &str = "request_path";
const STATUS_CODE_TAG: &str = "status_code";
const IS_SUCCESS_TAG: &str = "is_success";
const SDK_KEY_TAG: &str = "sdk_key";
const SOURCE_SERVICE_TAG: &str = "source_service";
const ID_LIST_FILE_ID_TAG: &str = "id_list_file_id";

/// Log tag for every message emitted from this module.
const TAG: &str = stringify!(NetworkClient);
43
44pub struct NetworkClient {
45    headers: HashMap<String, String>,
46    is_shutdown: Arc<AtomicBool>,
47    ops_stats: Arc<OpsStatsForInstance>,
48    net_provider: Weak<dyn NetworkProvider>,
49    disable_network: bool,
50    proxy_config: Option<ProxyConfig>,
51    ca_cert_pem: Option<Vec<u8>>,
52    silent_on_network_failure: bool,
53    disable_file_streaming: bool,
54    loggable_sdk_key: String,
55}
56
57impl NetworkClient {
58    #[must_use]
59    pub fn new(
60        sdk_key: &str,
61        headers: Option<HashMap<String, String>>,
62        options: Option<&StatsigOptions>,
63    ) -> Self {
64        let net_provider = get_network_provider();
65        let (disable_network, proxy_config, ca_cert_pem) = options
66            .map(|opts| {
67                let ca_cert_pem = opts
68                    .proxy_config
69                    .as_ref()
70                    .and_then(|cfg| cfg.ca_cert_path.as_ref())
71                    .and_then(|path| {
72                        if path.is_empty() {
73                            return None;
74                        }
75                        match std::fs::read(path) {
76                            Ok(bytes) => Some(bytes),
77                            Err(e) => {
78                                log_w!(
79                                    TAG,
80                                    "Failed to read proxy_config.ca_cert_path '{}': {}",
81                                    path,
82                                    e
83                                );
84                                None
85                            }
86                        }
87                    });
88                (
89                    opts.disable_network.unwrap_or(false),
90                    opts.proxy_config.clone(),
91                    ca_cert_pem,
92                )
93            })
94            .unwrap_or((false, None, None));
95
96        NetworkClient {
97            headers: headers.unwrap_or_default(),
98            is_shutdown: Arc::new(AtomicBool::new(false)),
99            net_provider,
100            ops_stats: OPS_STATS.get_for_instance(sdk_key),
101            disable_network,
102            proxy_config,
103            ca_cert_pem,
104            silent_on_network_failure: false,
105            disable_file_streaming: options
106                .map(|opts| opts.disable_disk_access.unwrap_or(false))
107                .unwrap_or(false),
108            loggable_sdk_key: get_loggable_sdk_key(sdk_key),
109        }
110    }
111
112    pub fn shutdown(&self) {
113        self.is_shutdown.store(true, Ordering::SeqCst);
114    }
115
116    pub async fn get(&self, request_args: RequestArgs) -> Result<Response, NetworkError> {
117        self.make_request(HttpMethod::GET, request_args).await
118    }
119
120    pub async fn post(
121        &self,
122        mut request_args: RequestArgs,
123        body: Option<Vec<u8>>,
124    ) -> Result<Response, NetworkError> {
125        request_args.body = body;
126        self.make_request(HttpMethod::POST, request_args).await
127    }
128
129    async fn make_request(
130        &self,
131        method: HttpMethod,
132        mut request_args: RequestArgs,
133    ) -> Result<Response, NetworkError> {
134        let is_shutdown = if let Some(is_shutdown) = &request_args.is_shutdown {
135            is_shutdown.clone()
136        } else {
137            self.is_shutdown.clone()
138        };
139
140        if self.disable_network {
141            log_d!(TAG, "Network is disabled, not making requests");
142            return Err(NetworkError::DisableNetworkOn(request_args.url));
143        }
144
145        request_args.populate_headers(self.headers.clone());
146
147        if request_args.disable_file_streaming.is_none() {
148            request_args.disable_file_streaming = Some(self.disable_file_streaming);
149        }
150
151        if request_args.ca_cert_pem.is_none() {
152            request_args.ca_cert_pem = self.ca_cert_pem.clone();
153        }
154
155        let mut merged_headers = request_args.headers.unwrap_or_default();
156        if !self.headers.is_empty() {
157            merged_headers.extend(self.headers.clone());
158        }
159        merged_headers.insert(
160            "STATSIG-CLIENT-TIME".into(),
161            Utc::now().timestamp_millis().to_string(),
162        );
163        request_args.headers = Some(merged_headers);
164
165        // passing down proxy config through request args
166        if let Some(proxy_config) = &self.proxy_config {
167            request_args.proxy_config = Some(proxy_config.clone());
168        }
169        let mut attempt = 0;
170
171        loop {
172            if let Some(key) = request_args.diagnostics_key {
173                self.ops_stats.add_marker(
174                    Marker::new(key, ActionType::Start, Some(StepType::NetworkRequest))
175                        .with_attempt(attempt)
176                        .with_url(request_args.url.clone()),
177                    None,
178                );
179            }
180            if is_shutdown.load(Ordering::SeqCst) {
181                log_i!(TAG, "{}", SHUTDOWN_ERROR);
182                return Err(NetworkError::ShutdownError(request_args.url));
183            }
184
185            let request_start = Instant::now();
186            let mut response = match self.net_provider.upgrade() {
187                Some(net_provider) => net_provider.send(&method, &request_args).await,
188                None => {
189                    return Err(NetworkError::RequestFailed(
190                        request_args.url,
191                        None,
192                        "Failed to get a NetworkProvider instance".to_string(),
193                    ));
194                }
195            };
196
197            let status = response.status_code;
198            let error_message = response
199                .error
200                .clone()
201                .unwrap_or_else(|| get_error_message_for_status(status, response.data.as_mut()));
202
203            let content_type = response
204                .data
205                .as_ref()
206                .and_then(|data| data.get_header_ref("content-type"));
207
208            log_d!(
209                TAG,
210                "Response url({}) status({:?}) content-type({:?})",
211                &request_args.url,
212                response.status_code,
213                content_type
214            );
215
216            let sdk_region_str = response
217                .data
218                .as_ref()
219                .and_then(|data| data.get_header_ref("x-statsig-region").cloned());
220            let success = (200..300).contains(&status.unwrap_or(0));
221            let duration_ms = request_start.elapsed().as_millis() as f64;
222            self.log_network_request_latency_to_ob(&request_args, status, success, duration_ms);
223
224            if let Some(key) = request_args.diagnostics_key {
225                let mut end_marker =
226                    Marker::new(key, ActionType::End, Some(StepType::NetworkRequest))
227                        .with_attempt(attempt)
228                        .with_url(request_args.url.clone())
229                        .with_is_success(success)
230                        .with_content_type(content_type.cloned())
231                        .with_sdk_region(sdk_region_str.map(|s| s.to_owned()));
232
233                if let Some(status_code) = status {
234                    end_marker = end_marker.with_status_code(status_code);
235                }
236
237                let error_map = if !error_message.is_empty() {
238                    let mut map = HashMap::new();
239                    map.insert("name".to_string(), "NetworkError".to_string());
240                    map.insert("message".to_string(), error_message.clone());
241                    let status_string = match status {
242                        Some(code) => code.to_string(),
243                        None => "None".to_string(),
244                    };
245                    map.insert("code".to_string(), status_string);
246                    Some(map)
247                } else {
248                    None
249                };
250
251                if let Some(error_map) = error_map {
252                    end_marker = end_marker.with_error(error_map);
253                }
254
255                self.ops_stats.add_marker(end_marker, None);
256            }
257
258            if success {
259                return Ok(response);
260            }
261
262            if NON_RETRY_CODES.contains(&status.unwrap_or(0)) {
263                let error = NetworkError::RequestNotRetryable(
264                    request_args.url.clone(),
265                    status,
266                    error_message,
267                );
268                self.log_warning(&error, &request_args);
269                return Err(error);
270            }
271
272            if attempt >= request_args.retries {
273                let error = NetworkError::RetriesExhausted(
274                    request_args.url.clone(),
275                    status,
276                    attempt + 1,
277                    error_message,
278                );
279                self.log_warning(&error, &request_args);
280                return Err(error);
281            }
282
283            attempt += 1;
284            let backoff_ms = 2_u64.pow(attempt) * 100;
285
286            log_i!(
287                TAG, "Network request failed with status code {} (attempt {}/{}), will retry after {}ms...\n{}",
288                status.map_or("unknown".to_string(), |s| s.to_string()),
289                attempt,
290                request_args.retries + 1,
291                backoff_ms,
292                error_message
293            );
294
295            tokio::time::sleep(Duration::from_millis(backoff_ms)).await;
296        }
297    }
298
299    pub fn mute_network_error_log(mut self) -> Self {
300        self.silent_on_network_failure = true;
301        self
302    }
303
304    // Logging helpers
305    fn log_warning(&self, error: &NetworkError, args: &RequestArgs) {
306        let exception = error.name();
307
308        log_w!(TAG, "{}", error);
309        if !self.silent_on_network_failure {
310            let dedupe_key = format!("{:?}", args.diagnostics_key);
311            self.ops_stats.log_error(ErrorBoundaryEvent {
312                tag: TAG.to_string(),
313                exception: exception.to_string(),
314                bypass_dedupe: false,
315                info: serde_json::to_string(error).unwrap_or_default(),
316                dedupe_key: Some(dedupe_key),
317                extra: None,
318            });
319        }
320    }
321
322    // ------------------------------------------------------------
323    // Observability Logging Helpers (OB only) - START
324    // ------------------------------------------------------------
325    fn log_network_request_latency_to_ob(
326        &self,
327        request_args: &RequestArgs,
328        status: Option<u16>,
329        success: bool,
330        duration_ms: f64,
331    ) {
332        let url = request_args.url.as_str();
333        if !should_log_network_request_latency(url) {
334            return;
335        }
336
337        let status_code = status
338            .map(|code| code.to_string())
339            .unwrap_or("none".to_string());
340        let tags = get_network_request_latency_tags(
341            request_args,
342            status_code,
343            success,
344            self.loggable_sdk_key.clone(),
345        );
346
347        self.ops_stats.log(ObservabilityEvent::new_event(
348            MetricType::Dist,
349            NETWORK_REQUEST_LATENCY_METRIC.to_string(),
350            duration_ms,
351            Some(tags),
352        ));
353    }
354}
355
356fn get_network_request_latency_tags(
357    request_args: &RequestArgs,
358    status_code: String,
359    success: bool,
360    loggable_sdk_key: String,
361) -> HashMap<String, String> {
362    let (source_service, request_path) = get_source_service_and_request_path(&request_args.url);
363    let mut tags = HashMap::from([
364        (REQUEST_PATH_TAG.to_string(), request_path),
365        (STATUS_CODE_TAG.to_string(), status_code),
366        (IS_SUCCESS_TAG.to_string(), success.to_string()),
367        (SDK_KEY_TAG.to_string(), loggable_sdk_key),
368        (SOURCE_SERVICE_TAG.to_string(), source_service),
369    ]);
370    if let Some(id_list_file_id) = request_args
371        .id_list_file_id
372        .as_ref()
373        .filter(|id| !id.is_empty())
374    {
375        tags.insert(ID_LIST_FILE_ID_TAG.to_string(), id_list_file_id.clone());
376    }
377
378    tags
379}
380
381fn is_latency_loggable_endpoint(endpoint: &str) -> bool {
382    endpoint == DOWNLOAD_CONFIG_SPECS_ENDPOINT
383        || endpoint == GET_ID_LISTS_ENDPOINT
384        || endpoint == DOWNLOAD_ID_LIST_FILE_ENDPOINT
385}
386
387fn get_version_and_endpoint_for_latency<'a>(
388    segments: &'a [&'a str],
389) -> Option<(usize, &'a str, &'a str)> {
390    // Find a known endpoint pattern, then verify the segment right before it is `/v{number}`.
391    segments
392        .iter()
393        .enumerate()
394        .find_map(|(endpoint_index, endpoint_segment)| {
395            if !is_latency_loggable_endpoint(endpoint_segment) || endpoint_index == 0 {
396                return None;
397            }
398
399            let version_index = endpoint_index - 1;
400            let version_segment = segments[version_index];
401            is_version_segment(version_segment).then_some((
402                version_index,
403                version_segment,
404                *endpoint_segment,
405            ))
406        })
407}
408
409fn should_log_network_request_latency(url: &str) -> bool {
410    let (_, raw_path) = split_host_and_path(url);
411    let normalized_path = strip_query_and_fragment(raw_path).trim_start_matches('/');
412    let segments: Vec<&str> = normalized_path
413        .split('/')
414        .filter(|segment| !segment.is_empty())
415        .collect();
416
417    get_version_and_endpoint_for_latency(&segments).is_some()
418}
419
/// Concatenates `host_prefix` and `path` verbatim, with no separator added.
/// An empty prefix yields `path` unchanged.
fn with_host_prefix(host_prefix: &str, path: &str) -> String {
    match host_prefix {
        "" => path.to_owned(),
        prefix => format!("{prefix}{path}"),
    }
}
427
428fn get_source_service_and_request_path(url: &str) -> (String, String) {
429    let (host_prefix, raw_path) = split_host_and_path(url);
430    let normalized_path = strip_query_and_fragment(raw_path).trim_start_matches('/');
431    let segments: Vec<&str> = normalized_path
432        .split('/')
433        .filter(|segment| !segment.is_empty())
434        .collect();
435
436    if let Some((version_index, version_segment, endpoint_segment)) =
437        get_version_and_endpoint_for_latency(&segments)
438    {
439        let request_path = format!("/{version_segment}/{endpoint_segment}");
440        let source_service_suffix = segments[..version_index].join("/");
441        let source_service = with_host_prefix(&host_prefix, &source_service_suffix)
442            .trim_end_matches('/')
443            .to_string();
444        return (source_service, request_path);
445    }
446
447    let fallback_request_path: String = normalized_path
448        .chars()
449        .take(MAX_REQUEST_PATH_LENGTH)
450        .collect();
451    let request_path = if fallback_request_path.is_empty() {
452        "/".to_string()
453    } else {
454        format!("/{fallback_request_path}")
455    };
456    let source_service = host_prefix.trim_end_matches('/').to_string();
457    (source_service, request_path)
458}
459
/// Test helper: the `request_path` half of `get_source_service_and_request_path`.
#[cfg(test)]
fn get_request_path(url: &str) -> String {
    let (_source_service, request_path) = get_source_service_and_request_path(url);
    request_path
}
464
465// ------------------------------------------------------------
466// Observability Logging Helpers (OB only) - END
467// ------------------------------------------------------------
468
469fn get_error_message_for_status(
470    status: Option<u16>,
471    data: Option<&mut super::ResponseData>,
472) -> String {
473    if (200..300).contains(&status.unwrap_or(0)) {
474        return String::new();
475    }
476
477    let mut message = String::new();
478    if let Some(data) = data {
479        let lossy_str = data.read_to_string().unwrap_or_default();
480        if lossy_str.is_ascii() {
481            message = lossy_str.to_string();
482        }
483    }
484
485    let status_value = match status {
486        Some(code) => code,
487        None => return format!("HTTP Error None: {message}"),
488    };
489
490    let generic_message = match status_value {
491        400 => "Bad Request",
492        401 => "Unauthorized",
493        403 => "Forbidden",
494        404 => "Not Found",
495        405 => "Method Not Allowed",
496        406 => "Not Acceptable",
497        408 => "Request Timeout",
498        500 => "Internal Server Error",
499        502 => "Bad Gateway",
500        503 => "Service Unavailable",
501        504 => "Gateway Timeout",
502        0 => "Unknown Error",
503        _ => return format!("HTTP Error {status_value}: {message}"),
504    };
505
506    if message.is_empty() {
507        return generic_message.to_string();
508    }
509
510    format!("{generic_message}: {message}")
511}
512
#[cfg(test)]
mod tests {
    use super::{
        get_network_request_latency_tags, get_request_path, get_source_service_and_request_path,
        should_log_network_request_latency, ID_LIST_FILE_ID_TAG, REQUEST_PATH_TAG,
    };
    use crate::networking::RequestArgs;

    #[test]
    fn test_get_request_path_with_sample_urls() {
        // (url, expected request_path): trailing ids/secrets and query strings are dropped.
        let cases = [
            (
                "https://api.statsigcdn.com/v1/download_id_list_file/3wHgh0FhoQH0p",
                "/v1/download_id_list_file",
            ),
            (
                "https://api.statsigcdn.com/v1/download_id_list_file/Q9mXcz7L1P43tRb8kV2dHyw%2FM6nJf0Ae5uTqsrC4Gp9KZ?foo=bar",
                "/v1/download_id_list_file",
            ),
            (
                "https://api.statsig.com/v1/get_id_lists/secret-abcdef",
                "/v1/get_id_lists",
            ),
            (
                "https://api.statsigcdn.com/v2/download_config_specs/secret-123456",
                "/v2/download_config_specs",
            ),
        ];

        for (url, expected) in cases {
            assert_eq!(get_request_path(url), expected, "request path for {url}");
        }
    }

    #[test]
    fn test_should_log_network_request_latency_for_supported_endpoints() {
        let skipped = [
            "https://api.statsig.com/v1/log_event",
            "https://api.statsig.com/v1/sdk_exception",
        ];
        let logged = [
            "https://api.statsig.com/v1/get_id_lists/secret-abcdef",
            "https://api.statsigcdn.com/v2/download_config_specs/secret-123456",
            "https://api.statsigcdn.com/v1/download_id_list_file/3wHgh0FhoQH0p",
        ];

        for url in skipped {
            assert!(!should_log_network_request_latency(url), "should skip {url}");
        }
        for url in logged {
            assert!(should_log_network_request_latency(url), "should log {url}");
        }
    }

    #[test]
    fn test_get_source_service_and_request_path() {
        let url = "http://127.0.0.1:12345/mock-uuid/v2/download_config_specs/secret-key.json?x=1";
        let (source_service, request_path) = get_source_service_and_request_path(url);

        assert_eq!(source_service, "http://127.0.0.1:12345/mock-uuid");
        assert_eq!(request_path, "/v2/download_config_specs");
    }

    #[test]
    fn test_network_latency_tags_include_id_list_file_id_only_when_present() {
        let tags_for = |args: &RequestArgs| {
            get_network_request_latency_tags(
                args,
                "200".to_string(),
                true,
                "client-key-123".to_string(),
            )
        };

        let mut request_args = RequestArgs {
            url: "https://api.statsigcdn.com/v1/download_id_list_file/file-123".to_string(),
            id_list_file_id: Some("file-123".to_string()),
            ..RequestArgs::new()
        };

        let tags = tags_for(&request_args);
        assert_eq!(
            tags.get(ID_LIST_FILE_ID_TAG).map(String::as_str),
            Some("file-123")
        );
        assert_eq!(
            tags.get(REQUEST_PATH_TAG).map(String::as_str),
            Some("/v1/download_id_list_file")
        );

        // An empty id must be omitted entirely.
        request_args.id_list_file_id = Some(String::new());
        assert!(!tags_for(&request_args).contains_key(ID_LIST_FILE_ID_TAG));
    }
}