statsig_rust/networking/providers/net_provider_reqwest.rs

1use std::collections::HashMap;
2use std::time::Duration;
3
4use async_trait::async_trait;
5
6use crate::{
7    log_e, log_w,
8    networking::{
9        http_types::{HttpMethod, RequestArgs, Response},
10        NetworkProvider,
11    },
12};
13
14use crate::networking::proxy_config::ProxyConfig;
15use reqwest::Method;
16
/// Log tag passed to this module's `log_w!` / `log_e!` calls.
const TAG: &str = "NetworkProviderReqwest";
18
/// [`NetworkProvider`] implementation that performs requests with the `reqwest` HTTP client.
///
/// The type holds no state; every request builds its own `reqwest::Client`.
pub struct NetworkProviderReqwest {}
20
21#[async_trait]
22impl NetworkProvider for NetworkProviderReqwest {
23    async fn send(&self, method: &HttpMethod, args: &RequestArgs) -> Response {
24        if let Some(is_shutdown) = &args.is_shutdown {
25            if is_shutdown.load(std::sync::atomic::Ordering::SeqCst) {
26                return Response {
27                    status_code: 0,
28                    data: None,
29                    error: Some("Request was shutdown".to_string()),
30                    headers: None,
31                };
32            }
33        }
34
35        let request = self.build_request(method, args);
36
37        let error;
38        let mut status_code = 0;
39        let mut data = None;
40        let mut headers = None;
41
42        match request.send().await {
43            Ok(response) => {
44                status_code = response.status().as_u16();
45                headers = get_response_headers(&response);
46                data = response.bytes().await.ok().map(|bytes| bytes.to_vec());
47                error = None;
48            }
49            Err(e) => {
50                let error_message = get_error_message(e);
51                log_w!(TAG, "Request Error: {} {}", &args.url, error_message);
52                error = Some(error_message);
53            }
54        }
55
56        Response {
57            status_code,
58            data,
59            error,
60            headers,
61        }
62    }
63}
64
65impl NetworkProviderReqwest {
66    fn build_request(
67        &self,
68        method: &HttpMethod,
69        request_args: &RequestArgs,
70    ) -> reqwest::RequestBuilder {
71        let method_actual = match method {
72            HttpMethod::GET => Method::GET,
73            HttpMethod::POST => Method::POST,
74        };
75        let is_post = method_actual == Method::POST;
76
77        let mut client_builder = reqwest::Client::builder();
78
79        // configure proxy if available
80        if let Some(proxy_config) = request_args.proxy_config.as_ref() {
81            client_builder = Self::configure_proxy(client_builder, proxy_config);
82        }
83
84        let client = client_builder.build().unwrap_or_else(|e| {
85            log_e!(TAG, "Failed to build reqwest client with proxy config: {}. Falling back to default client.", e);
86            reqwest::Client::new()
87        });
88
89        let mut request = client.request(method_actual, &request_args.url);
90
91        let timeout_duration = match request_args.timeout_ms > 0 {
92            true => Duration::from_millis(request_args.timeout_ms),
93            false => Duration::from_secs(10),
94        };
95        request = request.timeout(timeout_duration);
96
97        if let Some(headers) = &request_args.headers {
98            for (key, value) in headers {
99                request = request.header(key, value);
100            }
101        }
102
103        if let Some(params) = &request_args.query_params {
104            request = request.query(params);
105        }
106
107        if is_post {
108            let bytes = match &request_args.body {
109                Some(b) => b.clone(),
110                None => vec![],
111            };
112            let byte_len = bytes.len();
113
114            request = request.body(bytes);
115            request = request.header("Content-Length", byte_len.to_string());
116        }
117
118        request
119    }
120
121    fn configure_proxy(
122        client_builder: reqwest::ClientBuilder,
123        proxy_config: &ProxyConfig,
124    ) -> reqwest::ClientBuilder {
125        let (Some(host), Some(port)) = (&proxy_config.proxy_host, &proxy_config.proxy_port) else {
126            return client_builder;
127        };
128
129        let proxy_url = format!(
130            "{}://{}:{}",
131            proxy_config.proxy_protocol.as_deref().unwrap_or("http"),
132            host,
133            port
134        );
135
136        let Ok(proxy) = reqwest::Proxy::all(&proxy_url) else {
137            log_w!(TAG, "Failed to create proxy for URL: {}", proxy_url);
138            return client_builder;
139        };
140
141        let Some(auth) = &proxy_config.proxy_auth else {
142            return client_builder.proxy(proxy);
143        };
144
145        let Some((username, password)) = auth.split_once(':') else {
146            log_w!(
147                TAG,
148                "Invalid proxy auth format. Expected 'username:password'"
149            );
150            return client_builder.proxy(proxy);
151        };
152
153        client_builder.proxy(proxy.basic_auth(username, password))
154    }
155}
156
157fn get_error_message(error: reqwest::Error) -> String {
158    let mut error_message = error.to_string();
159
160    if let Some(url_error) = error.url() {
161        error_message.push_str(&format!(". URL: {}", url_error));
162    }
163
164    if let Some(status_error) = error.status() {
165        error_message.push_str(&format!(". Status: {}", status_error));
166    }
167
168    error_message
169}
170
171fn get_response_headers(response: &reqwest::Response) -> Option<HashMap<String, String>> {
172    let headers = response.headers();
173    if headers.is_empty() {
174        return None;
175    }
176
177    let mut headers_map = HashMap::new();
178    for (key, value) in headers {
179        headers_map.insert(key.to_string(), value.to_str().unwrap_or("").to_string());
180    }
181
182    Some(headers_map)
183}