statsig_rust/networking/providers/curl.rs

1use crate::networking::{HttpMethod, NetworkProvider, RequestArgs, Response};
2use crate::observability::util::sanitize_url_for_logging;
3use crate::{log_d, log_e, ok_or_return_with, unwrap_or_return_with, StatsigErr};
4use async_trait::async_trait;
5use chrono::Utc;
6use curl::easy::{Easy2, Handler, List, WriteError};
7use curl::multi::Easy2Handle;
8use curl::multi::{self, Multi};
9use std::collections::HashMap;
10use std::sync::{Arc, Mutex};
11use std::thread::{self, JoinHandle};
12use std::time::Duration;
13use tokio::sync::{mpsc, oneshot};
14use tokio::{runtime, time};
15
/// Capacity of the bounded queue that feeds requests to the cURL run loop.
/// Once this many requests are waiting, `send` waits for a free slot.
const MAX_QUEUED_REQUESTS: usize = 10;
17
18lazy_static::lazy_static! {
19    static ref CURL: Mutex<HashMap<String, Arc<CurlContext>>> = Mutex::new(HashMap::new());
20}
21
22struct Request {
23    method: HttpMethod,
24    args: RequestArgs,
25    tx: oneshot::Sender<Result<Response, StatsigErr>>,
26}
27
28struct ActiveRequest {
29    request: Request,
30    handle: Easy2Handle<Collector>,
31}
32
/// Run-loop state shared by every `Curl` handle created for the same SDK key.
///
/// Fields drop in declaration order. Dropping the context closes `req_tx` and
/// then `_abort_tx`. Closing the abort channel resolves the run loop's
/// `abort_rx`, which breaks the loop and lets the thread exit.
struct CurlContext {
    /// Bounded queue feeding requests to the run-loop thread.
    req_tx: mpsc::Sender<Request>,
    /// Held only for its drop side effect; nothing in this file sends on it.
    _abort_tx: Option<oneshot::Sender<()>>,
    /// Join handle of the run-loop thread; `None` if the thread failed to spawn.
    _handle: Option<Arc<JoinHandle<()>>>,
}
38
/// Log tag attached to every message emitted from this module.
const TAG: &str = "Curl";
40
41pub struct Curl {
42    sdk_key: String,
43    context: Arc<CurlContext>,
44}
45
46impl Drop for Curl {
47    fn drop(&mut self) {
48        let count = Arc::strong_count(&self.context);
49
50        if count <= 2 {
51            if let Ok(mut curl_map) = CURL.lock() {
52                curl_map.remove(&self.sdk_key);
53            }
54        }
55    }
56}
57
58#[async_trait]
59impl NetworkProvider for Curl {
60    async fn send(&self, method: &HttpMethod, request_args: &RequestArgs) -> Response {
61        let method_name = if method == &HttpMethod::POST {
62            "POST"
63        } else {
64            "GET"
65        };
66        log_d!(TAG, "Sending {} Request: {}", method_name, request_args.url);
67
68        if let Some(headers) = &request_args.headers {
69            for (key, value) in headers {
70                log_d!(TAG, "Header: {} = {}", key, value);
71            }
72        }
73
74        let (response_tx, response_rx) = oneshot::channel();
75        let request = Request {
76            method: method.clone(),
77            args: request_args.clone(),
78            tx: response_tx,
79        };
80
81        match self.context.req_tx.send(request).await {
82            Ok(()) => (),
83            Err(e) => {
84                return Response {
85                    status_code: 0,
86                    data: None,
87                    error: Some(e.to_string()),
88                }
89            }
90        }
91
92        let result = response_rx.await.unwrap_or_else(|e| {
93            log_e!(TAG, "Failed to receive response: {:?}", e);
94            Err(StatsigErr::NetworkError(e.to_string()))
95        });
96
97        result.unwrap_or_else(|e| Response {
98            status_code: 0,
99            data: None,
100            error: Some(e.to_string()),
101        })
102    }
103}
104
105impl Curl {
106    #[must_use]
107    pub fn get_instance(sdk_key: &str) -> Self {
108        let mut curl_map = match CURL.lock() {
109            Ok(map) => map,
110            Err(e) => {
111                log_e!(TAG, "Failed to acquire lock on CURL: {}", e);
112                return Curl::new(sdk_key);
113            }
114        };
115
116        if let Some(curl) = curl_map.get(sdk_key) {
117            Curl {
118                sdk_key: sdk_key.to_string(),
119                context: curl.clone(),
120            }
121        } else {
122            let curl = Curl::new(sdk_key);
123            curl_map.insert(sdk_key.to_string(), curl.context.clone());
124            curl
125        }
126    }
127
128    fn new(sdk_key: &str) -> Curl {
129        let (handle, abort_tx, req_tx) = Self::create_run_loop();
130
131        Curl {
132            sdk_key: sdk_key.to_string(),
133            context: Arc::new(CurlContext {
134                req_tx,
135                _abort_tx: Some(abort_tx),
136                _handle: handle.map(Arc::new),
137            }),
138        }
139    }
140
141    fn create_run_loop() -> (
142        Option<JoinHandle<()>>,
143        oneshot::Sender<()>,
144        mpsc::Sender<Request>,
145    ) {
146        let (abort_tx, abort_rx) = oneshot::channel::<()>();
147        let (req_tx, req_rx) = mpsc::channel::<Request>(MAX_QUEUED_REQUESTS);
148
149        let handle_result = thread::Builder::new()
150            .name("curl-run-loop".to_string())
151            .spawn(move || {
152                let rt = match runtime::Builder::new_current_thread().enable_all().build() {
153                    Ok(rt) => rt,
154                    Err(e) => {
155                        log_e!(TAG, "Failed to build cURL runtime: {:?}", e);
156                        return;
157                    }
158                };
159
160                rt.block_on(Self::run(abort_rx, req_rx));
161            });
162
163        let handle = match handle_result {
164            Ok(handle) => {
165                log_d!(TAG, "New cURL run loop created.");
166                Some(handle)
167            }
168            Err(e) => {
169                log_e!(TAG, "Failed to spawn cURL run loop: {:?}", e);
170                None
171            }
172        };
173
174        (handle, abort_tx, req_tx)
175    }
176
177    async fn run(mut abort_rx: oneshot::Receiver<()>, mut req_rx: mpsc::Receiver<Request>) {
178        let multi = Multi::new();
179        let mut active_reqs = HashMap::new();
180        let mut next_token = 0;
181
182        loop {
183            tokio::select! {
184                _ = &mut abort_rx => {
185                    break;
186                }
187                () = time::sleep(Duration::from_millis(1)), if !active_reqs.is_empty() => {}
188                Some(request) = req_rx.recv() => {
189                    if active_reqs.is_empty() {
190                        next_token = 0;
191                    }
192
193                    if let Err(e) = Self::add_request_for_processing(&multi, &mut active_reqs, &mut next_token, request) {
194                        log_e!(TAG, "Failed to add request for processing: {:?}", e);
195                    }
196                }
197            }
198
199            Self::remove_shutdown_requests(&multi, &mut active_reqs);
200            Self::process_active_requests(&multi, &mut active_reqs);
201        }
202    }
203
204    fn add_request_for_processing(
205        multi: &Multi,
206        handles: &mut HashMap<usize, ActiveRequest>,
207        next_token: &mut usize,
208        request: Request,
209    ) -> Result<(), StatsigErr> {
210        let args = &request.args;
211        let easy = construct_easy_request(&request.method, args)
212            .map_err(|e| StatsigErr::NetworkError(e.to_string()))?;
213
214        match multi.add2(easy) {
215            Ok(mut handle) => {
216                handle
217                    .set_token(*next_token)
218                    .map_err(|e| StatsigErr::NetworkError(e.to_string()))?;
219                handles.insert(*next_token, ActiveRequest { request, handle });
220                *next_token = next_token.wrapping_add(1);
221                Ok(())
222            }
223            Err(e) => Err(StatsigErr::NetworkError(e.to_string())),
224        }
225    }
226
227    fn remove_shutdown_requests(multi: &multi::Multi, active: &mut HashMap<usize, ActiveRequest>) {
228        let to_remove: Vec<usize> = active
229            .iter()
230            .filter_map(|(token, entry)| {
231                if let Some(is_shutdown) = &entry.request.args.is_shutdown {
232                    if is_shutdown.load(std::sync::atomic::Ordering::SeqCst) {
233                        return Some(*token);
234                    }
235                }
236                None
237            })
238            .collect();
239
240        for token in to_remove {
241            if let Some(entry) = active.remove(&token) {
242                let _ = entry.request.tx.send(Err(StatsigErr::NetworkError(
243                    "Request was shutdown".to_string(),
244                )));
245
246                if let Err(e) = multi.remove2(entry.handle) {
247                    log_e!(TAG, "Failed to remove request from multi: {:?}", e);
248                }
249            }
250        }
251    }
252
253    fn process_active_requests(multi: &Multi, active: &mut HashMap<usize, ActiveRequest>) {
254        let perform = match multi.perform() {
255            Ok(perform) => perform,
256            Err(e) => {
257                log_e!(TAG, "Failed to perform requests: {:?}", e);
258                return;
259            }
260        };
261
262        if perform == 0 {
263            log_d!(TAG, "No requests performed");
264        }
265
266        multi.messages(|msg| {
267            let token = ok_or_return_with!(msg.token(), |e| {
268                log_e!(TAG, "Failed to get token: {:?}", e);
269            });
270
271            let mut entry = unwrap_or_return_with!(active.remove(&token), || {
272                log_e!(TAG, "Token not found: {}", token);
273            });
274            let url = &entry.request.args.url;
275
276            let result = unwrap_or_return_with!(msg.result_for2(&entry.handle), || {
277                log_e!(TAG, "Failed to get result for token: {}", token);
278            });
279            let sanitized_url = sanitize_url_for_logging(url);
280
281            match result {
282                Ok(()) => {
283                    let http_status = entry.handle.response_code().unwrap_or_else(|e| {
284                        log_e!(TAG, "Failed to get HTTP status: {:?}", e);
285                        0
286                    });
287
288                    let res_buffer = entry.handle.get_mut().get_buffer();
289                    log_d!(
290                        TAG,
291                        "Transfer succeeded (Status: {}) (Download length: {}) {}",
292                        http_status,
293                        &res_buffer.len(),
294                        sanitized_url
295                    );
296
297                    let data = String::from_utf8(res_buffer)
298                        .map_err(|e| {
299                            log_e!(
300                                TAG,
301                                "Failed to convert response to string: {} {:?}",
302                                sanitized_url,
303                                e
304                            );
305                            e
306                        })
307                        .ok();
308
309                    let response = Response {
310                        data,
311                        status_code: http_status as u16,
312                        error: None,
313                    };
314
315                    if entry.request.tx.send(Ok(response)).is_err() {
316                        log_e!(TAG, "Failed to broadcast response: {}", sanitized_url);
317                    }
318                }
319                Err(e) => {
320                    log_e!(TAG, "Failed to send request to {}: {:?}", sanitized_url, e);
321                    let _ = entry
322                        .request
323                        .tx
324                        .send(Err(StatsigErr::NetworkError(e.to_string())));
325                    return;
326                }
327            };
328
329            if let Err(e) = multi.remove2(entry.handle) {
330                log_e!(TAG, "Failed to remove request from multi: {:?}", e);
331            }
332
333            log_d!(TAG, "Request completed: {}", sanitized_url);
334        });
335    }
336}
337
338fn construct_easy_request(
339    method: &HttpMethod,
340    args: &RequestArgs,
341) -> Result<Easy2<Collector>, curl::Error> {
342    let mut easy = Easy2::new(Collector::new());
343
344    if args.timeout_ms > 0 {
345        easy.timeout(Duration::from_millis(args.timeout_ms))?;
346    } else {
347        easy.timeout(Duration::from_secs(10))?;
348    }
349
350    if args.accept_gzip_response {
351        easy.accept_encoding("gzip")?;
352    }
353
354    if *method == HttpMethod::POST {
355        easy.post(true)?;
356    }
357
358    let mut headers = List::new();
359
360    headers.append(&format!(
361        "statsig-client-time: {}",
362        Utc::now().timestamp_millis()
363    ))?;
364
365    if let Some(body) = &args.body {
366        easy.post_fields_copy(body)?;
367        headers.append("Content-Type: application/json")?;
368    }
369
370    if let Some(additional_headers) = &args.headers {
371        for (key, value) in additional_headers {
372            headers.append(&format!("{key}: {value}"))?;
373        }
374    }
375    easy.http_headers(headers)?;
376
377    if let Some(params) = &args.query_params {
378        let query_string = params
379            .iter()
380            .map(|(k, v)| format!("{k}={v}"))
381            .collect::<Vec<_>>()
382            .join("&");
383        easy.url(&format!("{}?{}", args.url, query_string))?;
384    } else {
385        easy.url(&args.url)?;
386    }
387
388    Ok(easy)
389}
390
/// In-memory sink that accumulates a transfer's response body.
struct Collector {
    /// Bytes received so far and not yet handed out.
    buffer: Vec<u8>,
}

impl Collector {
    /// Creates a collector with nothing buffered.
    fn new() -> Self {
        Collector { buffer: vec![] }
    }

    /// Moves every buffered byte out to the caller, leaving the collector empty.
    fn get_buffer(&mut self) -> Vec<u8> {
        std::mem::take(&mut self.buffer)
    }
}
404
405impl Handler for Collector {
406    fn write(&mut self, data: &[u8]) -> Result<usize, WriteError> {
407        self.buffer.extend_from_slice(data);
408        Ok(data.len())
409    }
410}
411
#[cfg(test)]
mod tests {
    use crate::Statsig;

    use super::*;
    use more_asserts::assert_le;
    use std::sync::atomic::AtomicBool;
    use std::time::Instant;
    use tokio::task;
    use wiremock::{
        http::Method as WiremockMethod,
        matchers::{method, path},
        Mock, MockServer, ResponseTemplate,
    };

    // Every test uses its own SDK key because `CURL` is one process-wide map
    // shared by all tests in this module.

    /// Two handles requested for the same key share a single `CurlContext`.
    #[test]
    fn test_only_one_instance() {
        let key = "key_1";
        let curl_service_1 = Curl::get_instance(key);
        let curl_service_2 = Curl::get_instance(key);

        assert!(Arc::ptr_eq(
            &curl_service_1.context,
            &curl_service_2.context
        ));
    }

    /// Each iteration creates a lone handle, which caches a fresh context.
    /// Dropping the handle evicts that context again.
    #[test]
    fn test_creating_multiples() {
        let key = "key_2";

        let mut last = 0;
        for _ in 0..10 {
            assert!(CURL.lock().unwrap().get(key).is_none());
            let c = Curl::get_instance(key);
            // NOTE(review): this compares against the address of the context
            // freed at the end of the previous iteration. An allocator may reuse
            // that address, which would make `assert_ne!` flaky. Confirm.
            let now = Arc::as_ptr(&c.context) as usize;

            assert_ne!(now, last);
            last = now;
            assert!(CURL.lock().unwrap().get(key).is_some());
            // `c` drops here as the only handle for `key`, evicting the entry.
        }

        assert!(CURL.lock().unwrap().get(key).is_none());
    }

    /// The cached entry survives until the last handle for the key is dropped.
    #[test]
    fn test_drop_releases_instance() {
        let key = "key_3";

        let curl_service_1 = Curl::get_instance(key);
        let curl_service_2 = Curl::get_instance(key);
        assert!(CURL.lock().unwrap().get(key).is_some());

        drop(curl_service_1);
        assert!(CURL.lock().unwrap().get(key).is_some());

        drop(curl_service_2);
        assert!(CURL.lock().unwrap().get(key).is_none());
    }

    /// Setting a request's `is_shutdown` flag ends it long before the mock
    /// server's 10 s response delay. The cache entry is gone once the spawned
    /// task (and its handle) finishes.
    #[tokio::test]
    async fn test_shutdown_kills_requests() {
        let key = "key_4";

        let server = MockServer::start().await;

        Mock::given(method(WiremockMethod::GET))
            .and(path("/test"))
            .respond_with(
                ResponseTemplate::new(200)
                    .set_body_string("{\"success\": true}")
                    .set_delay(Duration::from_millis(10_000)),
            )
            .mount(&server)
            .await;

        let shutdown = Arc::new(AtomicBool::new(false));
        let shutdown_clone = shutdown.clone();
        let handle = task::spawn(async move {
            let curl = Curl::get_instance(key);
            curl.send(
                &HttpMethod::GET,
                &RequestArgs {
                    is_shutdown: Some(shutdown_clone),
                    url: format!("{}/test", server.uri()),
                    ..RequestArgs::new()
                },
            )
            .await;
        });

        let start = Instant::now();
        // The flag may be set before the spawned task has even queued its
        // request; either way the run loop aborts it on its next iteration.
        shutdown.store(true, std::sync::atomic::Ordering::SeqCst);
        handle.await.unwrap();

        assert_le!(start.elapsed().as_millis(), 100);
        time::sleep(Duration::from_millis(100)).await;
        assert!(CURL.lock().unwrap().get(key).is_none());
    }

    /// Shutting down and dropping a `Statsig` instance releases the cached
    /// context for its key. Only the map eviction is asserted; the thread is
    /// not joined here.
    #[tokio::test]
    async fn test_statsig_shutdown_kills_thread() {
        let key = "sdk_key_5";
        let statsig = Statsig::new(key, None);

        let _ = statsig.initialize().await;
        assert!(CURL.lock().unwrap().get(key).is_some());

        let _ = statsig.shutdown().await;
        drop(statsig);

        tokio::time::sleep(Duration::from_millis(1)).await;
        assert!(CURL.lock().unwrap().get(key).is_none());
    }

    /// Dropping the last handle drops the context. Its abort sender goes with
    /// it, so the run-loop thread finishes shortly after.
    #[tokio::test]
    async fn test_thread_dies_on_drop() {
        let key = "sdk_key_6";
        let curl = Curl::get_instance(key);
        let handle = curl.context._handle.clone();

        assert!(!handle.as_ref().is_some_and(|h| h.is_finished()));
        drop(curl);

        tokio::time::sleep(Duration::from_millis(100)).await;
        assert!(handle.as_ref().is_some_and(|h| h.is_finished()));
    }
}
539}