statsig_rust/networking/providers/net_provider_reqwest.rs

1use std::collections::HashMap;
2use std::io::{BufReader, Seek, SeekFrom, Write};
3use std::time::Duration;
4
5use async_trait::async_trait;
6
7use crate::{
8    log_e, log_w,
9    networking::{
10        http_types::{HttpMethod, RequestArgs, Response, ResponseData},
11        NetworkProvider,
12    },
13    StatsigErr,
14};
15
16use crate::networking::proxy_config::ProxyConfig;
17use reqwest::Method;
18
/// Tag passed to the logging macros for every message emitted by this provider.
const TAG: &str = "NetworkProviderReqwest";

/// [`NetworkProvider`] implementation that performs HTTP requests with `reqwest`.
///
/// The type holds no state of its own.
pub struct NetworkProviderReqwest {}
22
23#[async_trait]
24impl NetworkProvider for NetworkProviderReqwest {
25    async fn send(&self, method: &HttpMethod, args: &RequestArgs) -> Response {
26        if let Some(is_shutdown) = &args.is_shutdown {
27            if is_shutdown.load(std::sync::atomic::Ordering::SeqCst) {
28                return Response {
29                    status_code: None,
30                    data: None,
31                    error: Some("Request was shutdown".to_string()),
32                    headers: None,
33                };
34            }
35        }
36
37        let request = self.build_request(method, args);
38
39        let mut error = None;
40        let mut status_code = None;
41        let mut data = None;
42        let mut headers = None;
43
44        match request.send().await {
45            Ok(response) => {
46                status_code = Some(response.status().as_u16());
47                headers = get_response_headers(&response);
48
49                match Self::write_response_to_temp_file(response).await {
50                    Ok(response_data) => data = Some(response_data),
51                    Err(e) => {
52                        error = Some(e.to_string());
53                    }
54                }
55            }
56            Err(e) => {
57                let error_message = get_error_message(e);
58                error = Some(error_message);
59            }
60        }
61
62        Response {
63            status_code,
64            data,
65            error,
66            headers,
67        }
68    }
69}
70
71impl NetworkProviderReqwest {
72    fn build_request(
73        &self,
74        method: &HttpMethod,
75        request_args: &RequestArgs,
76    ) -> reqwest::RequestBuilder {
77        let method_actual = match method {
78            HttpMethod::GET => Method::GET,
79            HttpMethod::POST => Method::POST,
80        };
81        let is_post = method_actual == Method::POST;
82
83        let mut client_builder = reqwest::Client::builder();
84
85        // configure proxy if available
86        if let Some(proxy_config) = request_args.proxy_config.as_ref() {
87            client_builder = Self::configure_proxy(client_builder, proxy_config);
88        }
89
90        let client = client_builder.build().unwrap_or_else(|e| {
91            log_e!(TAG, "Failed to build reqwest client with proxy config: {}. Falling back to default client.", e);
92            reqwest::Client::new()
93        });
94
95        let mut request = client.request(method_actual, &request_args.url);
96
97        let timeout_duration = match request_args.timeout_ms > 0 {
98            true => Duration::from_millis(request_args.timeout_ms),
99            false => Duration::from_secs(10),
100        };
101        request = request.timeout(timeout_duration);
102
103        if let Some(headers) = &request_args.headers {
104            for (key, value) in headers {
105                request = request.header(key, value);
106            }
107        }
108
109        if let Some(params) = &request_args.query_params {
110            request = request.query(params);
111        }
112
113        if is_post {
114            let bytes = match &request_args.body {
115                Some(b) => b.clone(),
116                None => vec![],
117            };
118            let byte_len = bytes.len();
119
120            request = request.body(bytes);
121            request = request.header("Content-Length", byte_len.to_string());
122        }
123
124        request
125    }
126
127    fn configure_proxy(
128        client_builder: reqwest::ClientBuilder,
129        proxy_config: &ProxyConfig,
130    ) -> reqwest::ClientBuilder {
131        let (Some(host), Some(port)) = (&proxy_config.proxy_host, &proxy_config.proxy_port) else {
132            return client_builder;
133        };
134
135        let proxy_url = format!(
136            "{}://{}:{}",
137            proxy_config.proxy_protocol.as_deref().unwrap_or("http"),
138            host,
139            port
140        );
141
142        let Ok(proxy) = reqwest::Proxy::all(&proxy_url) else {
143            log_w!(TAG, "Failed to create proxy for URL: {}", proxy_url);
144            return client_builder;
145        };
146
147        let Some(auth) = &proxy_config.proxy_auth else {
148            return client_builder.proxy(proxy);
149        };
150
151        let Some((username, password)) = auth.split_once(':') else {
152            log_w!(
153                TAG,
154                "Invalid proxy auth format. Expected 'username:password'"
155            );
156            return client_builder.proxy(proxy);
157        };
158
159        client_builder.proxy(proxy.basic_auth(username, password))
160    }
161
162    async fn write_response_to_temp_file(
163        response: reqwest::Response,
164    ) -> Result<ResponseData, StatsigErr> {
165        let mut response = response;
166        let mut temp_file = tempfile::spooled_tempfile(1024 * 1024 * 2); // 2MB
167
168        while let Some(item) = response
169            .chunk()
170            .await
171            .map_err(|e| StatsigErr::FileError(e.to_string()))?
172        {
173            temp_file
174                .write_all(&item)
175                .map_err(|e| StatsigErr::FileError(e.to_string()))?;
176        }
177
178        temp_file
179            .seek(SeekFrom::Start(0))
180            .map_err(|e| StatsigErr::FileError(e.to_string()))?;
181
182        let reader = BufReader::new(temp_file);
183        Ok(ResponseData::from_stream(Box::new(reader)))
184    }
185}
186
187fn get_error_message(error: reqwest::Error) -> String {
188    let mut error_message = error.to_string();
189
190    if let Some(url_error) = error.url() {
191        error_message.push_str(&format!(". URL: {}", url_error));
192    }
193
194    if let Some(status_error) = error.status() {
195        error_message.push_str(&format!(". Status: {}", status_error));
196    }
197
198    error_message
199}
200
201fn get_response_headers(response: &reqwest::Response) -> Option<HashMap<String, String>> {
202    let headers = response.headers();
203    if headers.is_empty() {
204        return None;
205    }
206
207    let mut headers_map = HashMap::new();
208    for (key, value) in headers {
209        headers_map.insert(key.to_string(), value.to_str().unwrap_or("").to_string());
210    }
211
212    Some(headers_map)
213}