statsig_rust/networking/providers/
curl.rs1use crate::networking::{HttpMethod, NetworkProvider, RequestArgs, Response};
2use crate::observability::util::sanitize_url_for_logging;
3use crate::{log_d, log_e, ok_or_return_with, unwrap_or_return_with, StatsigErr};
4use async_trait::async_trait;
5use chrono::Utc;
6use curl::easy::{Easy2, Handler, List, WriteError};
7use curl::multi::Easy2Handle;
8use curl::multi::{self, Multi};
9use std::collections::HashMap;
10use std::sync::{Arc, Mutex};
11use std::thread::{self, JoinHandle};
12use std::time::Duration;
13use tokio::sync::{mpsc, oneshot};
14use tokio::{runtime, time};
15
/// Capacity of the bounded channel that feeds the cURL run loop.
///
/// Once this many requests are waiting to be picked up, `Curl::send` waits for room
/// before queuing another one.
const MAX_QUEUED_REQUESTS: usize = 10;
17
18lazy_static::lazy_static! {
19 static ref CURL: Mutex<HashMap<String, Arc<CurlContext>>> = Mutex::new(HashMap::new());
20}
21
22struct Request {
23 method: HttpMethod,
24 args: RequestArgs,
25 tx: oneshot::Sender<Result<Response, StatsigErr>>,
26}
27
28struct ActiveRequest {
29 request: Request,
30 handle: Easy2Handle<Collector>,
31}
32
33struct CurlContext {
34 req_tx: mpsc::Sender<Request>,
35 _abort_tx: Option<oneshot::Sender<()>>,
36 _handle: Option<Arc<JoinHandle<()>>>,
37}
38
/// Log tag used by every message emitted from this provider.
const TAG: &str = "Curl";
40
41pub struct Curl {
42 sdk_key: String,
43 context: Arc<CurlContext>,
44}
45
46impl Drop for Curl {
47 fn drop(&mut self) {
48 let count = Arc::strong_count(&self.context);
49
50 if count <= 2 {
51 if let Ok(mut curl_map) = CURL.lock() {
52 curl_map.remove(&self.sdk_key);
53 }
54 }
55 }
56}
57
58#[async_trait]
59impl NetworkProvider for Curl {
60 async fn send(&self, method: &HttpMethod, request_args: &RequestArgs) -> Response {
61 let method_name = if method == &HttpMethod::POST {
62 "POST"
63 } else {
64 "GET"
65 };
66 log_d!(TAG, "Sending {} Request: {}", method_name, request_args.url);
67
68 if let Some(headers) = &request_args.headers {
69 for (key, value) in headers {
70 log_d!(TAG, "Header: {} = {}", key, value);
71 }
72 }
73
74 let (response_tx, response_rx) = oneshot::channel();
75 let request = Request {
76 method: method.clone(),
77 args: request_args.clone(),
78 tx: response_tx,
79 };
80
81 match self.context.req_tx.send(request).await {
82 Ok(()) => (),
83 Err(e) => {
84 return Response {
85 status_code: 0,
86 data: None,
87 error: Some(e.to_string()),
88 }
89 }
90 }
91
92 let result = response_rx.await.unwrap_or_else(|e| {
93 log_e!(TAG, "Failed to receive response: {:?}", e);
94 Err(StatsigErr::NetworkError(e.to_string()))
95 });
96
97 result.unwrap_or_else(|e| Response {
98 status_code: 0,
99 data: None,
100 error: Some(e.to_string()),
101 })
102 }
103}
104
105impl Curl {
106 #[must_use]
107 pub fn get_instance(sdk_key: &str) -> Self {
108 let mut curl_map = match CURL.lock() {
109 Ok(map) => map,
110 Err(e) => {
111 log_e!(TAG, "Failed to acquire lock on CURL: {}", e);
112 return Curl::new(sdk_key);
113 }
114 };
115
116 if let Some(curl) = curl_map.get(sdk_key) {
117 Curl {
118 sdk_key: sdk_key.to_string(),
119 context: curl.clone(),
120 }
121 } else {
122 let curl = Curl::new(sdk_key);
123 curl_map.insert(sdk_key.to_string(), curl.context.clone());
124 curl
125 }
126 }
127
128 fn new(sdk_key: &str) -> Curl {
129 let (handle, abort_tx, req_tx) = Self::create_run_loop();
130
131 Curl {
132 sdk_key: sdk_key.to_string(),
133 context: Arc::new(CurlContext {
134 req_tx,
135 _abort_tx: Some(abort_tx),
136 _handle: handle.map(Arc::new),
137 }),
138 }
139 }
140
141 fn create_run_loop() -> (
142 Option<JoinHandle<()>>,
143 oneshot::Sender<()>,
144 mpsc::Sender<Request>,
145 ) {
146 let (abort_tx, abort_rx) = oneshot::channel::<()>();
147 let (req_tx, req_rx) = mpsc::channel::<Request>(MAX_QUEUED_REQUESTS);
148
149 let handle_result = thread::Builder::new()
150 .name("curl-run-loop".to_string())
151 .spawn(move || {
152 let rt = match runtime::Builder::new_current_thread().enable_all().build() {
153 Ok(rt) => rt,
154 Err(e) => {
155 log_e!(TAG, "Failed to build cURL runtime: {:?}", e);
156 return;
157 }
158 };
159
160 rt.block_on(Self::run(abort_rx, req_rx));
161 });
162
163 let handle = match handle_result {
164 Ok(handle) => {
165 log_d!(TAG, "New cURL run loop created.");
166 Some(handle)
167 }
168 Err(e) => {
169 log_e!(TAG, "Failed to spawn cURL run loop: {:?}", e);
170 None
171 }
172 };
173
174 (handle, abort_tx, req_tx)
175 }
176
177 async fn run(mut abort_rx: oneshot::Receiver<()>, mut req_rx: mpsc::Receiver<Request>) {
178 let multi = Multi::new();
179 let mut active_reqs = HashMap::new();
180 let mut next_token = 0;
181
182 loop {
183 tokio::select! {
184 _ = &mut abort_rx => {
185 break;
186 }
187 () = time::sleep(Duration::from_millis(1)), if !active_reqs.is_empty() => {}
188 Some(request) = req_rx.recv() => {
189 if active_reqs.is_empty() {
190 next_token = 0;
191 }
192
193 if let Err(e) = Self::add_request_for_processing(&multi, &mut active_reqs, &mut next_token, request) {
194 log_e!(TAG, "Failed to add request for processing: {:?}", e);
195 }
196 }
197 }
198
199 Self::remove_shutdown_requests(&multi, &mut active_reqs);
200 Self::process_active_requests(&multi, &mut active_reqs);
201 }
202 }
203
204 fn add_request_for_processing(
205 multi: &Multi,
206 handles: &mut HashMap<usize, ActiveRequest>,
207 next_token: &mut usize,
208 request: Request,
209 ) -> Result<(), StatsigErr> {
210 let args = &request.args;
211 let easy = construct_easy_request(&request.method, args)
212 .map_err(|e| StatsigErr::NetworkError(e.to_string()))?;
213
214 match multi.add2(easy) {
215 Ok(mut handle) => {
216 handle
217 .set_token(*next_token)
218 .map_err(|e| StatsigErr::NetworkError(e.to_string()))?;
219 handles.insert(*next_token, ActiveRequest { request, handle });
220 *next_token = next_token.wrapping_add(1);
221 Ok(())
222 }
223 Err(e) => Err(StatsigErr::NetworkError(e.to_string())),
224 }
225 }
226
227 fn remove_shutdown_requests(multi: &multi::Multi, active: &mut HashMap<usize, ActiveRequest>) {
228 let to_remove: Vec<usize> = active
229 .iter()
230 .filter_map(|(token, entry)| {
231 if let Some(is_shutdown) = &entry.request.args.is_shutdown {
232 if is_shutdown.load(std::sync::atomic::Ordering::SeqCst) {
233 return Some(*token);
234 }
235 }
236 None
237 })
238 .collect();
239
240 for token in to_remove {
241 if let Some(entry) = active.remove(&token) {
242 let _ = entry.request.tx.send(Err(StatsigErr::NetworkError(
243 "Request was shutdown".to_string(),
244 )));
245
246 if let Err(e) = multi.remove2(entry.handle) {
247 log_e!(TAG, "Failed to remove request from multi: {:?}", e);
248 }
249 }
250 }
251 }
252
253 fn process_active_requests(multi: &Multi, active: &mut HashMap<usize, ActiveRequest>) {
254 let perform = match multi.perform() {
255 Ok(perform) => perform,
256 Err(e) => {
257 log_e!(TAG, "Failed to perform requests: {:?}", e);
258 return;
259 }
260 };
261
262 if perform == 0 {
263 log_d!(TAG, "No requests performed");
264 }
265
266 multi.messages(|msg| {
267 let token = ok_or_return_with!(msg.token(), |e| {
268 log_e!(TAG, "Failed to get token: {:?}", e);
269 });
270
271 let mut entry = unwrap_or_return_with!(active.remove(&token), || {
272 log_e!(TAG, "Token not found: {}", token);
273 });
274 let url = &entry.request.args.url;
275
276 let result = unwrap_or_return_with!(msg.result_for2(&entry.handle), || {
277 log_e!(TAG, "Failed to get result for token: {}", token);
278 });
279 let sanitized_url = sanitize_url_for_logging(url);
280
281 match result {
282 Ok(()) => {
283 let http_status = entry.handle.response_code().unwrap_or_else(|e| {
284 log_e!(TAG, "Failed to get HTTP status: {:?}", e);
285 0
286 });
287
288 let res_buffer = entry.handle.get_mut().get_buffer();
289 log_d!(
290 TAG,
291 "Transfer succeeded (Status: {}) (Download length: {}) {}",
292 http_status,
293 &res_buffer.len(),
294 sanitized_url
295 );
296
297 let data = String::from_utf8(res_buffer)
298 .map_err(|e| {
299 log_e!(
300 TAG,
301 "Failed to convert response to string: {} {:?}",
302 sanitized_url,
303 e
304 );
305 e
306 })
307 .ok();
308
309 let response = Response {
310 data,
311 status_code: http_status as u16,
312 error: None,
313 };
314
315 if entry.request.tx.send(Ok(response)).is_err() {
316 log_e!(TAG, "Failed to broadcast response: {}", sanitized_url);
317 }
318 }
319 Err(e) => {
320 log_e!(TAG, "Failed to send request to {}: {:?}", sanitized_url, e);
321 let _ = entry
322 .request
323 .tx
324 .send(Err(StatsigErr::NetworkError(e.to_string())));
325 return;
326 }
327 };
328
329 if let Err(e) = multi.remove2(entry.handle) {
330 log_e!(TAG, "Failed to remove request from multi: {:?}", e);
331 }
332
333 log_d!(TAG, "Request completed: {}", sanitized_url);
334 });
335 }
336}
337
338fn construct_easy_request(
339 method: &HttpMethod,
340 args: &RequestArgs,
341) -> Result<Easy2<Collector>, curl::Error> {
342 let mut easy = Easy2::new(Collector::new());
343
344 if args.timeout_ms > 0 {
345 easy.timeout(Duration::from_millis(args.timeout_ms))?;
346 } else {
347 easy.timeout(Duration::from_secs(10))?;
348 }
349
350 if args.accept_gzip_response {
351 easy.accept_encoding("gzip")?;
352 }
353
354 if *method == HttpMethod::POST {
355 easy.post(true)?;
356 }
357
358 let mut headers = List::new();
359
360 headers.append(&format!(
361 "statsig-client-time: {}",
362 Utc::now().timestamp_millis()
363 ))?;
364
365 if let Some(body) = &args.body {
366 easy.post_fields_copy(body)?;
367 headers.append("Content-Type: application/json")?;
368 }
369
370 if let Some(additional_headers) = &args.headers {
371 for (key, value) in additional_headers {
372 headers.append(&format!("{key}: {value}"))?;
373 }
374 }
375 easy.http_headers(headers)?;
376
377 if let Some(params) = &args.query_params {
378 let query_string = params
379 .iter()
380 .map(|(k, v)| format!("{k}={v}"))
381 .collect::<Vec<_>>()
382 .join("&");
383 easy.url(&format!("{}?{}", args.url, query_string))?;
384 } else {
385 easy.url(&args.url)?;
386 }
387
388 Ok(easy)
389}
390
/// `curl` write handler that accumulates a transfer's response body in memory.
struct Collector {
    /// Bytes received so far for the current transfer.
    buffer: Vec<u8>,
}
394
395impl Collector {
396 fn new() -> Self {
397 Self { buffer: Vec::new() }
398 }
399
400 fn get_buffer(&mut self) -> Vec<u8> {
401 std::mem::take(&mut self.buffer)
402 }
403}
404
405impl Handler for Collector {
406 fn write(&mut self, data: &[u8]) -> Result<usize, WriteError> {
407 self.buffer.extend_from_slice(data);
408 Ok(data.len())
409 }
410}
411
#[cfg(test)]
mod tests {
    use crate::Statsig;

    use super::*;
    use more_asserts::assert_le;
    use std::sync::atomic::AtomicBool;
    use std::time::Instant;
    use tokio::task;
    use wiremock::{
        http::Method as WiremockMethod,
        matchers::{method, path},
        Mock, MockServer, ResponseTemplate,
    };

    /// Whether the global registry currently holds a context for `key`.
    fn is_registered(key: &str) -> bool {
        CURL.lock().unwrap().contains_key(key)
    }

    #[test]
    fn test_only_one_instance() {
        let key = "key_1";
        let first = Curl::get_instance(key);
        let second = Curl::get_instance(key);

        assert!(Arc::ptr_eq(&first.context, &second.context));
    }

    #[test]
    fn test_creating_multiples() {
        let key = "key_2";
        let mut previous_ptr: usize = 0;

        for _ in 0..10 {
            assert!(!is_registered(key));

            let curl = Curl::get_instance(key);
            let current_ptr = Arc::as_ptr(&curl.context) as usize;
            assert_ne!(current_ptr, previous_ptr);
            previous_ptr = current_ptr;

            assert!(is_registered(key));
        }

        assert!(!is_registered(key));
    }

    #[test]
    fn test_drop_releases_instance() {
        let key = "key_3";

        let first = Curl::get_instance(key);
        let second = Curl::get_instance(key);
        assert!(is_registered(key));

        // One holder left: the shared context must stay registered.
        drop(first);
        assert!(is_registered(key));

        // Last holder gone: the registry entry is released.
        drop(second);
        assert!(!is_registered(key));
    }

    #[tokio::test]
    async fn test_shutdown_kills_requests() {
        let key = "key_4";
        let server = MockServer::start().await;

        // The mock answers only after 10s, so finishing quickly proves the request was
        // cut short by the shutdown flag.
        Mock::given(method(WiremockMethod::GET))
            .and(path("/test"))
            .respond_with(
                ResponseTemplate::new(200)
                    .set_body_string("{\"success\": true}")
                    .set_delay(Duration::from_millis(10_000)),
            )
            .mount(&server)
            .await;

        let shutdown = Arc::new(AtomicBool::new(false));
        let task_shutdown = Arc::clone(&shutdown);
        let handle = task::spawn(async move {
            let curl = Curl::get_instance(key);
            curl.send(
                &HttpMethod::GET,
                &RequestArgs {
                    is_shutdown: Some(task_shutdown),
                    url: format!("{}/test", server.uri()),
                    ..RequestArgs::new()
                },
            )
            .await;
        });

        let started_at = Instant::now();
        shutdown.store(true, std::sync::atomic::Ordering::SeqCst);
        handle.await.unwrap();

        assert_le!(started_at.elapsed().as_millis(), 100);
        time::sleep(Duration::from_millis(100)).await;
        assert!(!is_registered(key));
    }

    #[tokio::test]
    async fn test_statsig_shutdown_kills_thread() {
        let key = "sdk_key_5";
        let statsig = Statsig::new(key, None);

        let _ = statsig.initialize().await;
        assert!(is_registered(key));

        let _ = statsig.shutdown().await;
        drop(statsig);

        tokio::time::sleep(Duration::from_millis(1)).await;
        assert!(!is_registered(key));
    }

    #[tokio::test]
    async fn test_thread_dies_on_drop() {
        let key = "sdk_key_6";
        let curl = Curl::get_instance(key);
        let join_handle = curl.context._handle.clone();
        let thread_finished = || join_handle.as_ref().is_some_and(|h| h.is_finished());

        assert!(!thread_finished());
        drop(curl);

        tokio::time::sleep(Duration::from_millis(100)).await;
        assert!(thread_finished());
    }
}