statsig_rust/
statsig_runtime.rs

1use crate::statsig_global::StatsigGlobal;
2use crate::StatsigErr;
3use crate::{log_d, log_e};
4use futures::future::join_all;
5use parking_lot::Mutex;
6use std::collections::HashMap;
7use std::future::Future;
8use std::sync::atomic::AtomicBool;
9use std::sync::Arc;
10use std::time::Duration;
11use tokio::runtime::{Builder, Handle, Runtime};
12use tokio::sync::Notify;
13use tokio::task::JoinHandle;
14
/// Log tag used by every message emitted from this module.
const TAG: &str = "StatsigRuntime";
16
17#[derive(Debug, Clone, PartialEq, Eq, Hash)]
18struct TaskId {
19    tag: String,
20    tokio_id: tokio::task::Id,
21}
22
23pub struct StatsigRuntime {
24    spawned_tasks: Arc<Mutex<HashMap<TaskId, JoinHandle<()>>>>,
25    shutdown_notify: Arc<Notify>,
26    is_shutdown: Arc<AtomicBool>,
27}
28
29impl StatsigRuntime {
30    #[must_use]
31    pub fn get_runtime() -> Arc<StatsigRuntime> {
32        create_runtime_if_required();
33
34        Arc::new(StatsigRuntime {
35            spawned_tasks: Arc::new(Mutex::new(HashMap::new())),
36            shutdown_notify: Arc::new(Notify::new()),
37            is_shutdown: Arc::new(AtomicBool::new(false)),
38        })
39    }
40
41    pub fn get_handle(&self) -> Result<Handle, StatsigErr> {
42        if let Ok(handle) = Handle::try_current() {
43            return Ok(handle);
44        }
45
46        let global = StatsigGlobal::get();
47        let mut rt = global
48            .tokio_runtime
49            .try_lock_for(Duration::from_secs(5))
50            .ok_or_else(|| StatsigErr::LockFailure("Failed to lock tokio runtime".to_string()))?;
51        if rt.is_none() {
52            *rt = Some(Arc::new(create_new_runtime()));
53        }
54        if let Some(rt) = rt.as_ref() {
55            return Ok(rt.handle().clone());
56        }
57
58        Err(StatsigErr::ThreadFailure(
59            "No tokio runtime found".to_string(),
60        ))
61    }
62
63    pub fn get_num_active_tasks(&self) -> usize {
64        match self.spawned_tasks.try_lock_for(Duration::from_secs(5)) {
65            Some(lock) => lock.len(),
66            None => {
67                log_e!(TAG, "Failed to lock spawned tasks for get_num_active_tasks");
68                0
69            }
70        }
71    }
72
73    pub fn shutdown(&self) {
74        self.shutdown_notify.notify_waiters();
75
76        match self.spawned_tasks.try_lock_for(Duration::from_secs(5)) {
77            Some(mut lock) => {
78                for (_, task) in lock.drain() {
79                    task.abort();
80                }
81            }
82            None => {
83                log_e!(TAG, "Failed to lock spawned tasks for shutdown");
84            }
85        }
86    }
87
88    pub fn spawn<F, Fut>(&self, tag: &str, task: F) -> Result<tokio::task::Id, StatsigErr>
89    where
90        F: FnOnce(Arc<Notify>) -> Fut + Send + 'static,
91        Fut: Future<Output = ()> + Send + 'static,
92    {
93        let tag_string = tag.to_string();
94        let shutdown_notify = self.shutdown_notify.clone();
95        let spawned_tasks = self.spawned_tasks.clone();
96        let is_shutdown = self.is_shutdown.clone();
97
98        log_d!(TAG, "Spawning task {}", tag);
99
100        let handle = self.get_handle()?.spawn(async move {
101            if is_shutdown.load(std::sync::atomic::Ordering::Relaxed) {
102                return;
103            }
104
105            let task_id = tokio::task::id();
106            log_d!(TAG, "Executing task {}.{}", tag_string, task_id);
107            task(shutdown_notify).await;
108            remove_join_handle_with_id(spawned_tasks, tag_string, &task_id);
109        });
110
111        Ok(self.insert_join_handle(tag, handle))
112    }
113
114    pub async fn await_tasks_with_tag(&self, tag: &str) {
115        let mut handles = Vec::new();
116
117        match self.spawned_tasks.try_lock_for(Duration::from_secs(5)) {
118            Some(mut lock) => {
119                let keys: Vec<TaskId> = lock.keys().cloned().collect();
120                for key in &keys {
121                    if key.tag == tag {
122                        let removed = if let Some(handle) = lock.remove(key) {
123                            handle
124                        } else {
125                            log_e!(TAG, "No running task found for tag {}", tag);
126                            continue;
127                        };
128
129                        handles.push(removed);
130                    }
131                }
132            }
133            None => {
134                log_e!(TAG, "Failed to lock spawned tasks for await_tasks_with_tag");
135                return;
136            }
137        };
138
139        join_all(handles).await;
140    }
141
142    pub async fn await_join_handle(
143        &self,
144        tag: &str,
145        handle_id: &tokio::task::Id,
146    ) -> Result<(), StatsigErr> {
147        let task_id = TaskId {
148            tag: tag.to_string(),
149            tokio_id: *handle_id,
150        };
151
152        let handle = match self.spawned_tasks.try_lock_for(Duration::from_secs(5)) {
153            Some(mut lock) => match lock.remove(&task_id) {
154                Some(handle) => handle,
155                None => {
156                    return Err(StatsigErr::ThreadFailure(
157                        "No running task found".to_string(),
158                    ));
159                }
160            },
161            None => {
162                log_e!(TAG, "Failed to lock spawned tasks for await_join_handle");
163                return Err(StatsigErr::ThreadFailure(
164                    "Failed to lock spawned tasks".to_string(),
165                ));
166            }
167        };
168
169        handle
170            .await
171            .map_err(|e| StatsigErr::ThreadFailure(e.to_string()))?;
172
173        Ok(())
174    }
175
176    fn insert_join_handle(&self, tag: &str, handle: JoinHandle<()>) -> tokio::task::Id {
177        let handle_id = handle.id();
178        let task_id = TaskId {
179            tag: tag.to_string(),
180            tokio_id: handle_id,
181        };
182
183        match self.spawned_tasks.try_lock_for(Duration::from_secs(5)) {
184            Some(mut lock) => {
185                lock.insert(task_id, handle);
186            }
187            None => {
188                log_e!(TAG, "Failed to lock spawned tasks for insert_join_handle");
189            }
190        }
191
192        handle_id
193    }
194}
195
196pub fn create_new_runtime() -> Runtime {
197    #[cfg(not(target_family = "wasm"))]
198    return Builder::new_multi_thread()
199        .worker_threads(5)
200        .thread_name("statsig")
201        .enable_all()
202        .build()
203        .expect("Failed to create a tokio Runtime");
204
205    #[cfg(target_family = "wasm")]
206    return Builder::new_current_thread()
207        .thread_name("statsig")
208        .enable_all()
209        .build()
210        .expect("Failed to create a tokio Runtime (single-threaded for wasm");
211}
212
213fn remove_join_handle_with_id(
214    spawned_tasks: Arc<Mutex<HashMap<TaskId, JoinHandle<()>>>>,
215    tag: String,
216    handle_id: &tokio::task::Id,
217) {
218    let task_id = TaskId {
219        tag,
220        tokio_id: *handle_id,
221    };
222
223    match spawned_tasks.try_lock_for(Duration::from_secs(5)) {
224        Some(mut lock) => {
225            lock.remove(&task_id);
226        }
227        None => {
228            log_e!(
229                TAG,
230                "Failed to lock spawned tasks for remove_join_handle_with_id"
231            );
232        }
233    }
234}
235
236fn create_runtime_if_required() {
237    if Handle::try_current().is_ok() {
238        log_d!(TAG, "External tokio runtime found");
239        return;
240    }
241
242    let global = StatsigGlobal::get();
243    let mut lock = global
244        .tokio_runtime
245        .try_lock_for(Duration::from_secs(5))
246        .expect("Failed to lock owned tokio runtime");
247
248    match lock.as_ref() {
249        Some(_) => {
250            log_d!(TAG, "Existing StatsigGlobal tokio runtime found");
251        }
252        None => {
253            log_d!(TAG, "Creating new tokio runtime for StatsigGlobal");
254            let rt = Arc::new(create_new_runtime());
255
256            lock.replace(rt);
257        }
258    };
259}
260
261impl Drop for StatsigRuntime {
262    fn drop(&mut self) {
263        self.shutdown();
264
265        // let opt_inner = match self.inner_runtime.lock() {
266        //     Ok(mut inner_runtime) => inner_runtime.take(),
267        //     Err(e) => {
268        //         log_e!(TAG, "Failed to lock inner runtime {}", e);
269        //         None
270        //     }
271        // };
272
273        // let inner = match opt_inner {
274        //     Some(inner) => inner,
275        //     None => {
276        //         log_d!(TAG, "Runtime owned by tokio");
277        //         return;
278        //     }
279        // };
280
281        // if Arc::strong_count(&inner) > 1 {
282        //     // Another instance is still using the Runtime, so we can't drop it
283        //     return;
284        // }
285
286        // if tokio::runtime::Handle::try_current().is_err() {
287        //     println!("Not inside the Tokio runtime. Will automatically drop(inner).");
288        //     // Not inside the Tokio runtime. Will automatically drop(inner).
289        //     return;
290        // }
291
292        // log_w!(TAG, "Attempt to shutdown runtime from inside runtime");
293        // std::thread::spawn(move || {
294        //     println!("Dropping inner runtime from outside the Tokio runtime");
295        //     // We should not drop from inside the runtime, but in the odd case we do,
296        //     // moving inner to a new thread will prevent a panic
297        //     drop(inner);
298        // });
299    }
300}