Skip to main content

statsig_rust/networking/providers/net_provider_reqwest.rs

1use std::collections::HashMap;
2use std::io::{BufReader, Seek, SeekFrom, Write};
3use std::time::Duration;
4
5use async_trait::async_trait;
6
7use crate::log_d;
8use crate::{
9    log_e, log_w,
10    networking::{
11        http_types::{HttpMethod, RequestArgs, Response, ResponseData},
12        NetworkProvider,
13    },
14    StatsigErr,
15};
16
17use crate::networking::proxy_config::ProxyConfig;
18use reqwest::Method;
19
20const TAG: &str = "NetworkProviderReqwest";
21
22pub struct NetworkProviderReqwest {
23    has_file_write_access: bool,
24}
25
26impl NetworkProviderReqwest {
27    pub fn new() -> Self {
28        Self {
29            has_file_write_access: tempfile::tempfile().is_ok(),
30        }
31    }
32}
33
34impl Default for NetworkProviderReqwest {
35    fn default() -> Self {
36        Self::new()
37    }
38}
39
40#[async_trait]
41impl NetworkProvider for NetworkProviderReqwest {
42    async fn send(&self, method: &HttpMethod, args: &RequestArgs) -> Response {
43        if let Some(is_shutdown) = &args.is_shutdown {
44            if is_shutdown.load(std::sync::atomic::Ordering::SeqCst) {
45                return Response {
46                    status_code: None,
47                    data: None,
48                    error: Some("Request was shutdown".to_string()),
49                };
50            }
51        }
52
53        let request = self.build_request(method, args);
54
55        let mut error = None;
56        let mut status_code = None;
57        let mut data = None;
58
59        match request.send().await {
60            Ok(response) => {
61                status_code = Some(response.status().as_u16());
62
63                let data_result =
64                    if !self.has_file_write_access || args.disable_file_streaming == Some(true) {
65                        Self::write_response_to_in_memory_buffer(response).await
66                    } else {
67                        Self::write_response_to_temp_file(response).await
68                    };
69
70                match data_result {
71                    Ok(response_data) => data = Some(response_data),
72                    Err(e) => {
73                        error = Some(e.to_string());
74                    }
75                }
76            }
77            Err(e) => {
78                let error_message = get_error_message(e);
79                error = Some(error_message);
80            }
81        }
82
83        Response {
84            status_code,
85            data,
86            error,
87        }
88    }
89}
90
91impl NetworkProviderReqwest {
92    fn build_request(
93        &self,
94        method: &HttpMethod,
95        request_args: &RequestArgs,
96    ) -> reqwest::RequestBuilder {
97        let method_actual = match method {
98            HttpMethod::GET => Method::GET,
99            HttpMethod::POST => Method::POST,
100        };
101        let is_post = method_actual == Method::POST;
102
103        let mut client_builder = reqwest::Client::builder();
104
105        // configure proxy if available
106        if let Some(proxy_config) = request_args.proxy_config.as_ref() {
107            client_builder = Self::configure_proxy(client_builder, proxy_config);
108        }
109
110        if let Some(ca_cert_pem) = &request_args.ca_cert_pem {
111            match reqwest::Certificate::from_pem(ca_cert_pem) {
112                Ok(cert) => {
113                    client_builder = client_builder.add_root_certificate(cert);
114                }
115                Err(e) => {
116                    log_e!(TAG, "Failed to parse network CA cert PEM: {}", e);
117                }
118            }
119        }
120
121        let client = client_builder.build().unwrap_or_else(|e| {
122            log_e!(TAG, "Failed to build reqwest client with proxy config: {}. Falling back to default client.", e);
123            reqwest::Client::new()
124        });
125
126        let mut request = client.request(method_actual, &request_args.url);
127
128        let timeout_duration = match request_args.timeout_ms > 0 {
129            true => Duration::from_millis(request_args.timeout_ms),
130            false => Duration::from_secs(10),
131        };
132        request = request.timeout(timeout_duration);
133
134        if let Some(headers) = &request_args.headers {
135            for (key, value) in headers {
136                request = request.header(key, value);
137            }
138        }
139
140        if let Some(params) = &request_args.query_params {
141            request = request.query(params);
142        }
143
144        if is_post {
145            let bytes = match &request_args.body {
146                Some(b) => b.clone(),
147                None => vec![],
148            };
149            let byte_len = bytes.len();
150
151            request = request.body(bytes);
152            request = request.header("Content-Length", byte_len.to_string());
153        }
154
155        request
156    }
157
158    fn configure_proxy(
159        client_builder: reqwest::ClientBuilder,
160        proxy_config: &ProxyConfig,
161    ) -> reqwest::ClientBuilder {
162        let (Some(host), Some(port)) = (&proxy_config.proxy_host, &proxy_config.proxy_port) else {
163            return client_builder;
164        };
165
166        let proxy_url = format!(
167            "{}://{}:{}",
168            proxy_config.proxy_protocol.as_deref().unwrap_or("http"),
169            host,
170            port
171        );
172
173        let Ok(proxy) = reqwest::Proxy::all(&proxy_url) else {
174            log_w!(TAG, "Failed to create proxy for URL: {}", proxy_url);
175            return client_builder;
176        };
177
178        let Some(auth) = &proxy_config.proxy_auth else {
179            return client_builder.proxy(proxy);
180        };
181
182        let Some((username, password)) = auth.split_once(':') else {
183            log_w!(
184                TAG,
185                "Invalid proxy auth format. Expected 'username:password'"
186            );
187            return client_builder.proxy(proxy);
188        };
189
190        client_builder.proxy(proxy.basic_auth(username, password))
191    }
192
193    async fn write_response_to_temp_file(
194        response: reqwest::Response,
195    ) -> Result<ResponseData, StatsigErr> {
196        let headers = get_response_headers(&response);
197        let mut response = response;
198        let mut temp_file = tempfile::spooled_tempfile(1024 * 1024 * 2); // 2MB
199
200        let mut total_bytes = 0;
201        while let Some(item) = response
202            .chunk()
203            .await
204            .map_err(|e| StatsigErr::FileError(e.to_string()))?
205        {
206            total_bytes += item.len();
207            temp_file
208                .write_all(&item)
209                .map_err(|e| StatsigErr::FileError(e.to_string()))?;
210        }
211
212        temp_file
213            .seek(SeekFrom::Start(0))
214            .map_err(|e| StatsigErr::FileError(e.to_string()))?;
215
216        let reader = BufReader::new(temp_file);
217
218        log_d!(TAG, "Wrote {} bytes to spooled temp file", total_bytes);
219
220        Ok(ResponseData::from_stream_with_headers(
221            Box::new(reader),
222            headers,
223        ))
224    }
225
226    async fn write_response_to_in_memory_buffer(
227        response: reqwest::Response,
228    ) -> Result<ResponseData, StatsigErr> {
229        let headers = get_response_headers(&response);
230        let bytes = response
231            .bytes()
232            .await
233            .map_err(|e| StatsigErr::SerializationError(e.to_string()))?;
234
235        log_d!(TAG, "Wrote {} bytes to in-memory buffer", bytes.len());
236
237        Ok(ResponseData::from_bytes_with_headers(
238            bytes.to_vec(),
239            headers,
240        ))
241    }
242}
243
244fn get_error_message(error: reqwest::Error) -> String {
245    let mut error_message = error.to_string();
246
247    if let Some(url_error) = error.url() {
248        error_message.push_str(&format!(". URL: {}", url_error));
249    }
250
251    if let Some(status_error) = error.status() {
252        error_message.push_str(&format!(". Status: {}", status_error));
253    }
254
255    error_message
256}
257
258fn get_response_headers(response: &reqwest::Response) -> Option<HashMap<String, String>> {
259    let headers = response.headers();
260    if headers.is_empty() {
261        return None;
262    }
263
264    let mut headers_map = HashMap::new();
265    for (key, value) in headers {
266        headers_map.insert(key.to_string(), value.to_str().unwrap_or("").to_string());
267    }
268
269    Some(headers_map)
270}