statsig_rust/networking/providers/net_provider_reqwest.rs

1use std::collections::HashMap;
2use std::time::Duration;
3
4use async_trait::async_trait;
5
6use crate::{
7    log_w,
8    networking::{
9        http_types::{HttpMethod, RequestArgs, Response},
10        NetworkProvider,
11    },
12};
13
14use reqwest::Method;
15
/// Log tag attached to every warning emitted by this provider.
const TAG: &str = "NetworkProviderReqwest";

/// [`NetworkProvider`] implementation backed by the `reqwest` HTTP client.
///
/// The struct carries no fields; all per-request configuration comes from the
/// `RequestArgs` passed to `send`.
pub struct NetworkProviderReqwest {}
19
20#[async_trait]
21impl NetworkProvider for NetworkProviderReqwest {
22    async fn send(&self, method: &HttpMethod, args: &RequestArgs) -> Response {
23        if let Some(is_shutdown) = &args.is_shutdown {
24            if is_shutdown.load(std::sync::atomic::Ordering::SeqCst) {
25                return Response {
26                    status_code: 0,
27                    data: None,
28                    error: Some("Request was shutdown".to_string()),
29                    headers: None,
30                };
31            }
32        }
33
34        let request = self.build_request(method, args);
35
36        let error;
37        let mut status_code = 0;
38        let mut data = None;
39        let mut headers = None;
40
41        match request.send().await {
42            Ok(response) => {
43                status_code = response.status().as_u16();
44                headers = get_response_headers(&response);
45                data = response.bytes().await.ok().map(|bytes| bytes.to_vec());
46                error = None;
47            }
48            Err(e) => {
49                let error_message = get_error_message(e);
50                log_w!(TAG, "Request Error: {} {}", &args.url, error_message);
51                error = Some(error_message);
52            }
53        }
54
55        Response {
56            status_code,
57            data,
58            error,
59            headers,
60        }
61    }
62}
63
64impl NetworkProviderReqwest {
65    fn build_request(
66        &self,
67        method: &HttpMethod,
68        request_args: &RequestArgs,
69    ) -> reqwest::RequestBuilder {
70        let method_actual = match method {
71            HttpMethod::GET => Method::GET,
72            HttpMethod::POST => Method::POST,
73        };
74        let is_post = method_actual == Method::POST;
75
76        let client = reqwest::Client::new();
77        let mut request = client.request(method_actual, &request_args.url);
78
79        let timeout_duration = match request_args.timeout_ms > 0 {
80            true => Duration::from_millis(request_args.timeout_ms),
81            false => Duration::from_secs(10),
82        };
83        request = request.timeout(timeout_duration);
84
85        if let Some(headers) = &request_args.headers {
86            for (key, value) in headers {
87                request = request.header(key, value);
88            }
89        }
90
91        if let Some(params) = &request_args.query_params {
92            request = request.query(params);
93        }
94
95        if is_post {
96            let bytes = match &request_args.body {
97                Some(b) => b.clone(),
98                None => vec![],
99            };
100            let byte_len = bytes.len();
101
102            request = request.body(bytes);
103            request = request.header("Content-Length", byte_len.to_string());
104        }
105
106        request
107    }
108}
109
110fn get_error_message(error: reqwest::Error) -> String {
111    let mut error_message = error.to_string();
112
113    if let Some(url_error) = error.url() {
114        error_message.push_str(&format!(". URL: {}", url_error));
115    }
116
117    if let Some(status_error) = error.status() {
118        error_message.push_str(&format!(". Status: {}", status_error));
119    }
120
121    error_message
122}
123
124fn get_response_headers(response: &reqwest::Response) -> Option<HashMap<String, String>> {
125    let headers = response.headers();
126    if headers.is_empty() {
127        return None;
128    }
129
130    let mut headers_map = HashMap::new();
131    for (key, value) in headers {
132        headers_map.insert(key.to_string(), value.to_str().unwrap_or("").to_string());
133    }
134
135    Some(headers_map)
136}