Skip to main content

statsig_rust/networking/providers/
net_provider_reqwest.rs

1use std::collections::HashMap;
2use std::io::{BufReader, Seek, SeekFrom, Write};
3use std::time::Duration;
4
5use async_trait::async_trait;
6
7use crate::log_d;
8use crate::{
9    log_e, log_w,
10    networking::{
11        http_types::{HttpMethod, RequestArgs, Response, ResponseData},
12        NetworkProvider,
13    },
14    utils::url_path_has_suffix,
15    StatsigErr,
16};
17
18use crate::networking::proxy_config::ProxyConfig;
19use reqwest::Method;
20
/// Log tag used by every message emitted from this provider.
const TAG: &str = "NetworkProviderReqwest";
/// Path segments identifying the `log_event` endpoint (eligible for connection reuse).
const LOG_EVENT_REUSE_PATH: &[&str] = &["v1", "log_event"];
/// Path segments identifying the `sdk_exception` endpoint (always eligible for reuse).
const SDK_EXCEPTION_REUSE_PATH: &[&str] = &["v1", "sdk_exception"];
24
25pub struct NetworkProviderReqwest {
26    has_file_write_access: bool,
27    shared_client: reqwest::Client,
28}
29
30impl NetworkProviderReqwest {
31    pub fn new() -> Self {
32        Self {
33            has_file_write_access: tempfile::tempfile().is_ok(),
34            shared_client: reqwest::Client::new(),
35        }
36    }
37}
38
39impl Default for NetworkProviderReqwest {
40    fn default() -> Self {
41        Self::new()
42    }
43}
44
45#[async_trait]
46impl NetworkProvider for NetworkProviderReqwest {
47    async fn send(&self, method: &HttpMethod, args: &RequestArgs) -> Response {
48        if let Some(is_shutdown) = &args.is_shutdown {
49            if is_shutdown.load(std::sync::atomic::Ordering::SeqCst) {
50                return Response {
51                    status_code: None,
52                    data: None,
53                    error: Some("Request was shutdown".to_string()),
54                };
55            }
56        }
57
58        let request = self.build_request(method, args);
59
60        let mut error = None;
61        let mut status_code = None;
62        let mut data = None;
63
64        match request.send().await {
65            Ok(response) => {
66                status_code = Some(response.status().as_u16());
67
68                let data_result =
69                    if !self.has_file_write_access || args.disable_file_streaming == Some(true) {
70                        Self::write_response_to_in_memory_buffer(response).await
71                    } else {
72                        Self::write_response_to_temp_file(response).await
73                    };
74
75                match data_result {
76                    Ok(response_data) => data = Some(response_data),
77                    Err(e) => {
78                        error = Some(e.to_string());
79                    }
80                }
81            }
82            Err(e) => {
83                let error_message = get_error_message(e);
84                error = Some(error_message);
85            }
86        }
87
88        Response {
89            status_code,
90            data,
91            error,
92        }
93    }
94}
95
96impl NetworkProviderReqwest {
97    fn build_request(
98        &self,
99        method: &HttpMethod,
100        request_args: &RequestArgs,
101    ) -> reqwest::RequestBuilder {
102        let method_actual = match method {
103            HttpMethod::GET => Method::GET,
104            HttpMethod::POST => Method::POST,
105        };
106        let is_post = method_actual == Method::POST;
107
108        let client = self.get_client(request_args);
109
110        let mut request = client.request(method_actual, &request_args.url);
111
112        let timeout_duration = match request_args.timeout_ms > 0 {
113            true => Duration::from_millis(request_args.timeout_ms),
114            false => Duration::from_secs(10),
115        };
116        request = request.timeout(timeout_duration);
117
118        if let Some(headers) = &request_args.headers {
119            for (key, value) in headers {
120                request = request.header(key, value);
121            }
122        }
123
124        if let Some(params) = &request_args.query_params {
125            request = request.query(params);
126        }
127
128        if is_post {
129            let bytes = match &request_args.body {
130                Some(b) => b.clone(),
131                None => vec![],
132            };
133            let byte_len = bytes.len();
134
135            request = request.body(bytes);
136            request = request.header("Content-Length", byte_len.to_string());
137        }
138
139        request
140    }
141
142    fn get_client(&self, request_args: &RequestArgs) -> reqwest::Client {
143        if !self.should_use_shared_client(request_args) {
144            return Self::build_client(request_args);
145        }
146
147        self.shared_client.clone()
148    }
149
150    fn should_use_shared_client(&self, request_args: &RequestArgs) -> bool {
151        (request_args.log_event_connection_reuse && is_log_event_endpoint(&request_args.url))
152            || is_sdk_exception_endpoint(&request_args.url)
153    }
154
155    fn build_client(request_args: &RequestArgs) -> reqwest::Client {
156        let mut client_builder = reqwest::Client::builder();
157
158        // configure proxy if available
159        if let Some(proxy_config) = request_args.proxy_config.as_ref() {
160            client_builder = Self::configure_proxy(client_builder, proxy_config);
161        }
162
163        if let Some(ca_cert_pem) = &request_args.ca_cert_pem {
164            match reqwest::Certificate::from_pem(ca_cert_pem) {
165                Ok(cert) => {
166                    client_builder = client_builder.add_root_certificate(cert);
167                }
168                Err(e) => {
169                    log_e!(TAG, "Failed to parse network CA cert PEM: {}", e);
170                }
171            }
172        }
173
174        client_builder.build().unwrap_or_else(|e| {
175            log_e!(TAG, "Failed to build reqwest client with proxy config: {}. Falling back to default client.", e);
176            reqwest::Client::new()
177        })
178    }
179
180    fn configure_proxy(
181        client_builder: reqwest::ClientBuilder,
182        proxy_config: &ProxyConfig,
183    ) -> reqwest::ClientBuilder {
184        let (Some(host), Some(port)) = (&proxy_config.proxy_host, &proxy_config.proxy_port) else {
185            return client_builder;
186        };
187
188        let proxy_url = format!(
189            "{}://{}:{}",
190            proxy_config.proxy_protocol.as_deref().unwrap_or("http"),
191            host,
192            port
193        );
194
195        let Ok(proxy) = reqwest::Proxy::all(&proxy_url) else {
196            log_w!(TAG, "Failed to create proxy for URL: {}", proxy_url);
197            return client_builder;
198        };
199
200        let Some(auth) = &proxy_config.proxy_auth else {
201            return client_builder.proxy(proxy);
202        };
203
204        let Some((username, password)) = auth.split_once(':') else {
205            log_w!(
206                TAG,
207                "Invalid proxy auth format. Expected 'username:password'"
208            );
209            return client_builder.proxy(proxy);
210        };
211
212        client_builder.proxy(proxy.basic_auth(username, password))
213    }
214
215    async fn write_response_to_temp_file(
216        response: reqwest::Response,
217    ) -> Result<ResponseData, StatsigErr> {
218        let headers = get_response_headers(&response);
219        let mut response = response;
220        let mut temp_file = tempfile::spooled_tempfile(1024 * 1024 * 2); // 2MB
221
222        let mut total_bytes = 0;
223        while let Some(item) = response
224            .chunk()
225            .await
226            .map_err(|e| StatsigErr::FileError(e.to_string()))?
227        {
228            total_bytes += item.len();
229            temp_file
230                .write_all(&item)
231                .map_err(|e| StatsigErr::FileError(e.to_string()))?;
232        }
233
234        temp_file
235            .seek(SeekFrom::Start(0))
236            .map_err(|e| StatsigErr::FileError(e.to_string()))?;
237
238        let reader = BufReader::new(temp_file);
239
240        log_d!(TAG, "Wrote {} bytes to spooled temp file", total_bytes);
241
242        Ok(ResponseData::from_stream_with_headers(
243            Box::new(reader),
244            headers,
245        ))
246    }
247
248    async fn write_response_to_in_memory_buffer(
249        response: reqwest::Response,
250    ) -> Result<ResponseData, StatsigErr> {
251        let headers = get_response_headers(&response);
252        let bytes = response
253            .bytes()
254            .await
255            .map_err(|e| StatsigErr::SerializationError(e.to_string()))?;
256
257        log_d!(TAG, "Wrote {} bytes to in-memory buffer", bytes.len());
258
259        Ok(ResponseData::from_bytes_with_headers(
260            bytes.to_vec(),
261            headers,
262        ))
263    }
264}
265
266fn get_error_message(error: reqwest::Error) -> String {
267    let mut error_message = error.to_string();
268
269    if let Some(url_error) = error.url() {
270        error_message.push_str(&format!(". URL: {}", url_error));
271    }
272
273    if let Some(status_error) = error.status() {
274        error_message.push_str(&format!(". Status: {}", status_error));
275    }
276
277    error_message
278}
279
280fn is_log_event_endpoint(url: &str) -> bool {
281    url_path_has_suffix(url, LOG_EVENT_REUSE_PATH)
282}
283
284fn is_sdk_exception_endpoint(url: &str) -> bool {
285    url_path_has_suffix(url, SDK_EXCEPTION_REUSE_PATH)
286}
287
288fn get_response_headers(response: &reqwest::Response) -> Option<HashMap<String, String>> {
289    let headers = response.headers();
290    if headers.is_empty() {
291        return None;
292    }
293
294    let mut headers_map = HashMap::new();
295    for (key, value) in headers {
296        headers_map.insert(key.to_string(), value.to_str().unwrap_or("").to_string());
297    }
298
299    Some(headers_map)
300}
301
#[cfg(test)]
mod tests {
    use super::{is_log_event_endpoint, is_sdk_exception_endpoint};

    /// Exact `v1/log_event` suffix matches, tolerating a trailing slash, a path
    /// prefix, and a query string; extra segments or near-miss names do not.
    #[test]
    fn test_is_log_event_endpoint_matches_exact_suffix() {
        let accepted = [
            "https://api.statsig.com/v1/log_event",
            "https://api.statsig.com/v1/log_event/",
            "https://api.statsig.com/prefix/v1/log_event?foo=bar",
        ];
        let rejected = [
            "https://api.statsig.com/v1/log_event/extra",
            "https://api.statsig.com/v1/log_events",
            "https://api.statsig.com/log_event",
        ];

        for url in &accepted {
            assert!(is_log_event_endpoint(url));
        }
        for url in &rejected {
            assert!(!is_log_event_endpoint(url));
        }
    }

    /// Exact `v1/sdk_exception` suffix matches, tolerating a path prefix and a
    /// fragment; extra segments or near-miss names do not.
    #[test]
    fn test_is_sdk_exception_endpoint_matches_exact_suffix() {
        let accepted = [
            "https://api.statsig.com/v1/sdk_exception",
            "https://api.statsig.com/prefix/v1/sdk_exception#frag",
        ];
        let rejected = [
            "https://api.statsig.com/v1/sdk_exception/extra",
            "https://api.statsig.com/v1/sdk_exceptions",
        ];

        for url in &accepted {
            assert!(is_sdk_exception_endpoint(url));
        }
        for url in &rejected {
            assert!(!is_sdk_exception_endpoint(url));
        }
    }
}