Skip to main content

statsig_rust/networking/
network_client.rs

1use chrono::Utc;
2
3use super::network_error::NetworkError;
4use super::providers::get_network_provider;
5use super::{HttpMethod, NetworkProvider, RequestArgs, Response};
6use crate::networking::proxy_config::ProxyConfig;
7use crate::observability::observability_client_adapter::{MetricType, ObservabilityEvent};
8use crate::observability::ops_stats::{OpsStatsForInstance, OPS_STATS};
9use crate::observability::ErrorBoundaryEvent;
10use crate::sdk_diagnostics::marker::{ActionType, Marker, StepType};
11use crate::utils::{
12    get_loggable_sdk_key, is_version_segment, split_host_and_path, strip_query_and_fragment,
13};
14use crate::{log_d, log_i, log_w, StatsigOptions};
15use std::collections::HashMap;
16use std::sync::atomic::{AtomicBool, Ordering};
17use std::sync::{Arc, Weak};
18use std::time::{Duration, Instant};
19
/// HTTP status codes that fail immediately instead of being retried (sorted ascending).
const NON_RETRY_CODES: [u16; 6] = [
    400, // Bad Request
    403, // Forbidden
    405, // Method Not Allowed
    413, // Payload Too Large
    429, // Too Many Requests
    501, // Not Implemented
];
/// Logged when a request is aborted because its shutdown flag is set.
const SHUTDOWN_ERROR: &str = "Request was aborted because the client is shutting down";

/// Maximum number of characters kept from an unrecognised request path in latency tags.
const MAX_REQUEST_PATH_LENGTH: usize = 64;

// Endpoints whose request latency is reported to observability.
const DOWNLOAD_CONFIG_SPECS_ENDPOINT: &str = "download_config_specs";
const GET_ID_LISTS_ENDPOINT: &str = "get_id_lists";
const DOWNLOAD_ID_LIST_FILE_ENDPOINT: &str = "download_id_list_file";

// Latency metric name and the tag keys attached to it.
const NETWORK_REQUEST_LATENCY_METRIC: &str = "network_request.latency";
const REQUEST_PATH_TAG: &str = "request_path";
const STATUS_CODE_TAG: &str = "status_code";
const IS_SUCCESS_TAG: &str = "is_success";
const SDK_KEY_TAG: &str = "sdk_key";
const SOURCE_SERVICE_TAG: &str = "source_service";
const ID_LIST_FILE_ID_TAG: &str = "id_list_file_id";
const DELTAS_USED_TAG: &str = "deltas_used";

/// Log tag for messages emitted by this module.
const TAG: &str = stringify!(NetworkClient);
44
45pub struct NetworkClient {
46    headers: HashMap<String, String>,
47    is_shutdown: Arc<AtomicBool>,
48    ops_stats: Arc<OpsStatsForInstance>,
49    net_provider: Weak<dyn NetworkProvider>,
50    disable_network: bool,
51    proxy_config: Option<ProxyConfig>,
52    ca_cert_pem: Option<Vec<u8>>,
53    silent_on_network_failure: bool,
54    disable_file_streaming: bool,
55    log_event_connection_reuse: bool,
56    loggable_sdk_key: String,
57}
58
59impl NetworkClient {
60    #[must_use]
61    pub fn new(
62        sdk_key: &str,
63        headers: Option<HashMap<String, String>>,
64        options: Option<&StatsigOptions>,
65    ) -> Self {
66        let net_provider = get_network_provider();
67        let (disable_network, proxy_config, ca_cert_pem, log_event_connection_reuse) = options
68            .map(|opts| {
69                let ca_cert_pem = opts
70                    .proxy_config
71                    .as_ref()
72                    .and_then(|cfg| cfg.ca_cert_path.as_ref())
73                    .and_then(|path| {
74                        if path.is_empty() {
75                            return None;
76                        }
77                        match std::fs::read(path) {
78                            Ok(bytes) => Some(bytes),
79                            Err(e) => {
80                                log_w!(
81                                    TAG,
82                                    "Failed to read proxy_config.ca_cert_path '{}': {}",
83                                    path,
84                                    e
85                                );
86                                None
87                            }
88                        }
89                    });
90                (
91                    opts.disable_network.unwrap_or(false),
92                    opts.proxy_config.clone(),
93                    ca_cert_pem,
94                    opts.log_event_connection_reuse.unwrap_or(false),
95                )
96            })
97            .unwrap_or((false, None, None, false));
98
99        NetworkClient {
100            headers: headers.unwrap_or_default(),
101            is_shutdown: Arc::new(AtomicBool::new(false)),
102            net_provider,
103            ops_stats: OPS_STATS.get_for_instance(sdk_key),
104            disable_network,
105            proxy_config,
106            ca_cert_pem,
107            silent_on_network_failure: false,
108            disable_file_streaming: options
109                .map(|opts| opts.disable_disk_access.unwrap_or(false))
110                .unwrap_or(false),
111            log_event_connection_reuse,
112            loggable_sdk_key: get_loggable_sdk_key(sdk_key),
113        }
114    }
115
116    pub fn shutdown(&self) {
117        self.is_shutdown.store(true, Ordering::SeqCst);
118    }
119
120    pub async fn get(&self, request_args: RequestArgs) -> Result<Response, NetworkError> {
121        self.make_request(HttpMethod::GET, request_args).await
122    }
123
124    pub async fn post(
125        &self,
126        mut request_args: RequestArgs,
127        body: Option<Vec<u8>>,
128    ) -> Result<Response, NetworkError> {
129        request_args.body = body;
130        self.make_request(HttpMethod::POST, request_args).await
131    }
132
133    async fn make_request(
134        &self,
135        method: HttpMethod,
136        mut request_args: RequestArgs,
137    ) -> Result<Response, NetworkError> {
138        let is_shutdown = if let Some(is_shutdown) = &request_args.is_shutdown {
139            is_shutdown.clone()
140        } else {
141            self.is_shutdown.clone()
142        };
143
144        if self.disable_network {
145            log_d!(TAG, "Network is disabled, not making requests");
146            return Err(NetworkError::DisableNetworkOn(request_args.url));
147        }
148
149        request_args.populate_headers(self.headers.clone());
150
151        if request_args.disable_file_streaming.is_none() {
152            request_args.disable_file_streaming = Some(self.disable_file_streaming);
153        }
154
155        if request_args.ca_cert_pem.is_none() {
156            request_args.ca_cert_pem = self.ca_cert_pem.clone();
157        }
158
159        if self.log_event_connection_reuse && !request_args.log_event_connection_reuse {
160            request_args.log_event_connection_reuse = true;
161        }
162
163        let mut merged_headers = request_args.headers.unwrap_or_default();
164        if !self.headers.is_empty() {
165            merged_headers.extend(self.headers.clone());
166        }
167        merged_headers.insert(
168            "STATSIG-CLIENT-TIME".into(),
169            Utc::now().timestamp_millis().to_string(),
170        );
171        request_args.headers = Some(merged_headers);
172
173        // passing down proxy config through request args
174        if let Some(proxy_config) = &self.proxy_config {
175            request_args.proxy_config = Some(proxy_config.clone());
176        }
177        let mut attempt = 0;
178
179        loop {
180            if let Some(key) = request_args.diagnostics_key {
181                self.ops_stats.add_marker(
182                    Marker::new(key, ActionType::Start, Some(StepType::NetworkRequest))
183                        .with_attempt(attempt)
184                        .with_url(request_args.url.clone()),
185                    None,
186                );
187            }
188            if is_shutdown.load(Ordering::SeqCst) {
189                log_i!(TAG, "{}", SHUTDOWN_ERROR);
190                return Err(NetworkError::ShutdownError(request_args.url));
191            }
192
193            let request_start = Instant::now();
194            let mut response = match self.net_provider.upgrade() {
195                Some(net_provider) => net_provider.send(&method, &request_args).await,
196                None => {
197                    return Err(NetworkError::RequestFailed(
198                        request_args.url,
199                        None,
200                        "Failed to get a NetworkProvider instance".to_string(),
201                    ));
202                }
203            };
204
205            let status = response.status_code;
206            let error_message = response
207                .error
208                .clone()
209                .unwrap_or_else(|| get_error_message_for_status(status, response.data.as_mut()));
210
211            let content_type = response
212                .data
213                .as_ref()
214                .and_then(|data| data.get_header_ref("content-type"));
215
216            log_d!(
217                TAG,
218                "Response url({}) status({:?}) content-type({:?})",
219                &request_args.url,
220                response.status_code,
221                content_type
222            );
223
224            let sdk_region_str = response
225                .data
226                .as_ref()
227                .and_then(|data| data.get_header_ref("x-statsig-region").cloned());
228            let success = (200..300).contains(&status.unwrap_or(0));
229            let duration_ms = request_start.elapsed().as_millis() as f64;
230            self.log_network_request_latency_to_ob(&request_args, status, success, duration_ms);
231
232            if let Some(key) = request_args.diagnostics_key {
233                let mut end_marker =
234                    Marker::new(key, ActionType::End, Some(StepType::NetworkRequest))
235                        .with_attempt(attempt)
236                        .with_url(request_args.url.clone())
237                        .with_is_success(success)
238                        .with_content_type(content_type.cloned())
239                        .with_sdk_region(sdk_region_str.map(|s| s.to_owned()));
240
241                if let Some(status_code) = status {
242                    end_marker = end_marker.with_status_code(status_code);
243                }
244
245                let error_map = if !error_message.is_empty() {
246                    let mut map = HashMap::new();
247                    map.insert("name".to_string(), "NetworkError".to_string());
248                    map.insert("message".to_string(), error_message.clone());
249                    let status_string = match status {
250                        Some(code) => code.to_string(),
251                        None => "None".to_string(),
252                    };
253                    map.insert("code".to_string(), status_string);
254                    Some(map)
255                } else {
256                    None
257                };
258
259                if let Some(error_map) = error_map {
260                    end_marker = end_marker.with_error(error_map);
261                }
262
263                self.ops_stats.add_marker(end_marker, None);
264            }
265
266            if success {
267                return Ok(response);
268            }
269
270            if NON_RETRY_CODES.contains(&status.unwrap_or(0)) {
271                let error = NetworkError::RequestNotRetryable(
272                    request_args.url.clone(),
273                    status,
274                    error_message,
275                );
276                self.log_warning(&error, &request_args);
277                return Err(error);
278            }
279
280            if attempt >= request_args.retries {
281                let error = NetworkError::RetriesExhausted(
282                    request_args.url.clone(),
283                    status,
284                    attempt + 1,
285                    error_message,
286                );
287                self.log_warning(&error, &request_args);
288                return Err(error);
289            }
290
291            attempt += 1;
292            let backoff_ms = 2_u64.pow(attempt) * 100;
293
294            log_i!(
295                TAG, "Network request failed with status code {} (attempt {}/{}), will retry after {}ms...\n{}",
296                status.map_or("unknown".to_string(), |s| s.to_string()),
297                attempt,
298                request_args.retries + 1,
299                backoff_ms,
300                error_message
301            );
302
303            tokio::time::sleep(Duration::from_millis(backoff_ms)).await;
304        }
305    }
306
307    pub fn mute_network_error_log(mut self) -> Self {
308        self.silent_on_network_failure = true;
309        self
310    }
311
312    // Logging helpers
313    fn log_warning(&self, error: &NetworkError, args: &RequestArgs) {
314        let exception = error.name();
315
316        log_w!(TAG, "{}", error);
317        if !self.silent_on_network_failure {
318            let dedupe_key = format!("{:?}", args.diagnostics_key);
319            self.ops_stats.log_error(ErrorBoundaryEvent {
320                tag: TAG.to_string(),
321                exception: exception.to_string(),
322                bypass_dedupe: false,
323                info: serde_json::to_string(error).unwrap_or_default(),
324                dedupe_key: Some(dedupe_key),
325                extra: None,
326            });
327        }
328    }
329
330    // ------------------------------------------------------------
331    // Observability Logging Helpers (OB only) - START
332    // ------------------------------------------------------------
333    fn log_network_request_latency_to_ob(
334        &self,
335        request_args: &RequestArgs,
336        status: Option<u16>,
337        success: bool,
338        duration_ms: f64,
339    ) {
340        let url = request_args.url.as_str();
341        if !should_log_network_request_latency(url) {
342            return;
343        }
344
345        let status_code = status
346            .map(|code| code.to_string())
347            .unwrap_or("none".to_string());
348        let tags = get_network_request_latency_tags(
349            request_args,
350            status_code,
351            success,
352            self.loggable_sdk_key.clone(),
353        );
354
355        self.ops_stats.log(ObservabilityEvent::new_event(
356            MetricType::Dist,
357            NETWORK_REQUEST_LATENCY_METRIC.to_string(),
358            duration_ms,
359            Some(tags),
360        ));
361    }
362}
363
364fn get_network_request_latency_tags(
365    request_args: &RequestArgs,
366    status_code: String,
367    success: bool,
368    loggable_sdk_key: String,
369) -> HashMap<String, String> {
370    let (source_service, request_path) = get_source_service_and_request_path(&request_args.url);
371    let mut tags = HashMap::from([
372        (REQUEST_PATH_TAG.to_string(), request_path),
373        (STATUS_CODE_TAG.to_string(), status_code),
374        (IS_SUCCESS_TAG.to_string(), success.to_string()),
375        (SDK_KEY_TAG.to_string(), loggable_sdk_key),
376        (SOURCE_SERVICE_TAG.to_string(), source_service),
377        (
378            DELTAS_USED_TAG.to_string(),
379            request_args.deltas_enabled.to_string(),
380        ),
381    ]);
382    if let Some(id_list_file_id) = request_args
383        .id_list_file_id
384        .as_ref()
385        .filter(|id| !id.is_empty())
386    {
387        tags.insert(ID_LIST_FILE_ID_TAG.to_string(), id_list_file_id.clone());
388    }
389
390    tags
391}
392
393fn is_latency_loggable_endpoint(endpoint: &str) -> bool {
394    endpoint == DOWNLOAD_CONFIG_SPECS_ENDPOINT
395        || endpoint == GET_ID_LISTS_ENDPOINT
396        || endpoint == DOWNLOAD_ID_LIST_FILE_ENDPOINT
397}
398
399fn get_version_and_endpoint_for_latency<'a>(
400    segments: &'a [&'a str],
401) -> Option<(usize, &'a str, &'a str)> {
402    // Find a known endpoint pattern, then verify the segment right before it is `/v{number}`.
403    segments
404        .iter()
405        .enumerate()
406        .find_map(|(endpoint_index, endpoint_segment)| {
407            if !is_latency_loggable_endpoint(endpoint_segment) || endpoint_index == 0 {
408                return None;
409            }
410
411            let version_index = endpoint_index - 1;
412            let version_segment = segments[version_index];
413            is_version_segment(version_segment).then_some((
414                version_index,
415                version_segment,
416                *endpoint_segment,
417            ))
418        })
419}
420
421fn should_log_network_request_latency(url: &str) -> bool {
422    let (_, raw_path) = split_host_and_path(url);
423    let normalized_path = strip_query_and_fragment(raw_path).trim_start_matches('/');
424    let segments: Vec<&str> = normalized_path
425        .split('/')
426        .filter(|segment| !segment.is_empty())
427        .collect();
428
429    get_version_and_endpoint_for_latency(&segments).is_some()
430}
431
/// Concatenates `host_prefix` and `path` into a new `String`.
///
/// An empty prefix yields `path` unchanged.
fn with_host_prefix(host_prefix: &str, path: &str) -> String {
    let mut joined = String::with_capacity(host_prefix.len() + path.len());
    joined.push_str(host_prefix);
    joined.push_str(path);
    joined
}
439
440fn get_source_service_and_request_path(url: &str) -> (String, String) {
441    let (host_prefix, raw_path) = split_host_and_path(url);
442    let normalized_path = strip_query_and_fragment(raw_path).trim_start_matches('/');
443    let segments: Vec<&str> = normalized_path
444        .split('/')
445        .filter(|segment| !segment.is_empty())
446        .collect();
447
448    if let Some((version_index, version_segment, endpoint_segment)) =
449        get_version_and_endpoint_for_latency(&segments)
450    {
451        let request_path = format!("/{version_segment}/{endpoint_segment}");
452        let source_service_suffix = segments[..version_index].join("/");
453        let source_service = with_host_prefix(&host_prefix, &source_service_suffix)
454            .trim_end_matches('/')
455            .to_string();
456        return (source_service, request_path);
457    }
458
459    let fallback_request_path: String = normalized_path
460        .chars()
461        .take(MAX_REQUEST_PATH_LENGTH)
462        .collect();
463    let request_path = if fallback_request_path.is_empty() {
464        "/".to_string()
465    } else {
466        format!("/{fallback_request_path}")
467    };
468    let source_service = host_prefix.trim_end_matches('/').to_string();
469    (source_service, request_path)
470}
471
/// Test-only shorthand returning just the request-path half of
/// `get_source_service_and_request_path`.
#[cfg(test)]
fn get_request_path(url: &str) -> String {
    let (_source_service, request_path) = get_source_service_and_request_path(url);
    request_path
}
476
477// ------------------------------------------------------------
478// Observability Logging Helpers (OB only) - END
479// ------------------------------------------------------------
480
481fn get_error_message_for_status(
482    status: Option<u16>,
483    data: Option<&mut super::ResponseData>,
484) -> String {
485    if (200..300).contains(&status.unwrap_or(0)) {
486        return String::new();
487    }
488
489    let mut message = String::new();
490    if let Some(data) = data {
491        let lossy_str = data.read_to_string().unwrap_or_default();
492        if lossy_str.is_ascii() {
493            message = lossy_str.to_string();
494        }
495    }
496
497    let status_value = match status {
498        Some(code) => code,
499        None => return format!("HTTP Error None: {message}"),
500    };
501
502    let generic_message = match status_value {
503        400 => "Bad Request",
504        401 => "Unauthorized",
505        403 => "Forbidden",
506        404 => "Not Found",
507        405 => "Method Not Allowed",
508        406 => "Not Acceptable",
509        408 => "Request Timeout",
510        500 => "Internal Server Error",
511        502 => "Bad Gateway",
512        503 => "Service Unavailable",
513        504 => "Gateway Timeout",
514        0 => "Unknown Error",
515        _ => return format!("HTTP Error {status_value}: {message}"),
516    };
517
518    if message.is_empty() {
519        return generic_message.to_string();
520    }
521
522    format!("{generic_message}: {message}")
523}
524
#[cfg(test)]
mod tests {
    use super::{
        get_network_request_latency_tags, get_request_path, get_source_service_and_request_path,
        should_log_network_request_latency, with_host_prefix, DELTAS_USED_TAG,
        ID_LIST_FILE_ID_TAG, MAX_REQUEST_PATH_LENGTH, REQUEST_PATH_TAG,
    };
    use crate::networking::RequestArgs;

    #[test]
    fn test_get_request_path_with_sample_urls() {
        assert_eq!(
            get_request_path("https://api.statsigcdn.com/v1/download_id_list_file/3wHgh0FhoQH0p"),
            "/v1/download_id_list_file"
        );
        assert_eq!(
            get_request_path("https://api.statsigcdn.com/v1/download_id_list_file/Q9mXcz7L1P43tRb8kV2dHyw%2FM6nJf0Ae5uTqsrC4Gp9KZ?foo=bar"),
            "/v1/download_id_list_file"
        );
        assert_eq!(
            get_request_path("https://api.statsig.com/v1/get_id_lists/secret-abcdef"),
            "/v1/get_id_lists"
        );
        assert_eq!(
            get_request_path("https://api.statsigcdn.com/v2/download_config_specs/secret-123456"),
            "/v2/download_config_specs"
        );
    }

    #[test]
    fn test_get_request_path_falls_back_to_truncated_path() {
        assert_eq!(
            get_request_path("https://api.statsig.com/v1/log_event"),
            "/v1/log_event"
        );

        let long_segment = "a".repeat(MAX_REQUEST_PATH_LENGTH * 2);
        let request_path = get_request_path(&format!("https://api.statsig.com/{long_segment}"));
        assert_eq!(
            request_path,
            format!("/{}", "a".repeat(MAX_REQUEST_PATH_LENGTH))
        );
    }

    #[test]
    fn test_should_log_network_request_latency_for_supported_endpoints() {
        assert!(!should_log_network_request_latency(
            "https://api.statsig.com/v1/log_event"
        ));
        assert!(!should_log_network_request_latency(
            "https://api.statsig.com/v1/sdk_exception"
        ));
        assert!(should_log_network_request_latency(
            "https://api.statsig.com/v1/get_id_lists/secret-abcdef"
        ));
        assert!(should_log_network_request_latency(
            "https://api.statsigcdn.com/v2/download_config_specs/secret-123456"
        ));
        assert!(should_log_network_request_latency(
            "https://api.statsigcdn.com/v1/download_id_list_file/3wHgh0FhoQH0p"
        ));
    }

    #[test]
    fn test_should_not_log_endpoint_without_version_prefix() {
        // A loggable endpoint in the first path position has no version segment before it.
        assert!(!should_log_network_request_latency(
            "https://api.statsig.com/download_config_specs/secret-123456"
        ));
    }

    #[test]
    fn test_get_source_service_and_request_path() {
        let (source_service, request_path) = get_source_service_and_request_path(
            "http://127.0.0.1:12345/mock-uuid/v2/download_config_specs/secret-key.json?x=1",
        );
        assert_eq!(source_service, "http://127.0.0.1:12345/mock-uuid");
        assert_eq!(request_path, "/v2/download_config_specs");
    }

    #[test]
    fn test_get_source_service_without_service_prefix() {
        let (source_service, request_path) = get_source_service_and_request_path(
            "https://api.statsigcdn.com/v2/download_config_specs/secret-123456",
        );
        assert_eq!(source_service, "https://api.statsigcdn.com");
        assert_eq!(request_path, "/v2/download_config_specs");
    }

    #[test]
    fn test_with_host_prefix() {
        assert_eq!(with_host_prefix("", "mock-uuid"), "mock-uuid");
        assert_eq!(
            with_host_prefix("http://127.0.0.1:12345/", "mock-uuid"),
            "http://127.0.0.1:12345/mock-uuid"
        );
    }

    #[test]
    fn test_network_latency_tags_include_id_list_file_id_only_when_present() {
        let mut request_args = RequestArgs {
            url: "https://api.statsigcdn.com/v1/download_id_list_file/file-123".to_string(),
            id_list_file_id: Some("file-123".to_string()),
            ..RequestArgs::new()
        };

        let tags = get_network_request_latency_tags(
            &request_args,
            "200".to_string(),
            true,
            "client-key-123".to_string(),
        );
        assert_eq!(tags.get(ID_LIST_FILE_ID_TAG), Some(&"file-123".to_string()));
        assert_eq!(tags.get(DELTAS_USED_TAG), Some(&"false".to_string()));
        assert_eq!(
            tags.get(REQUEST_PATH_TAG),
            Some(&"/v1/download_id_list_file".to_string())
        );

        request_args.id_list_file_id = Some(String::new());
        request_args.deltas_enabled = true;
        let tags_without_id = get_network_request_latency_tags(
            &request_args,
            "200".to_string(),
            true,
            "client-key-123".to_string(),
        );
        assert!(!tags_without_id.contains_key(ID_LIST_FILE_ID_TAG));
        assert_eq!(
            tags_without_id.get(DELTAS_USED_TAG),
            Some(&"true".to_string())
        );
    }
}