statsig_rust/
statsig_runtime.rs

1use crate::statsig_global::StatsigGlobal;
2use crate::StatsigErr;
3use crate::{log_d, log_e};
4use futures::future::join_all;
5use parking_lot::Mutex;
6use std::collections::HashMap;
7use std::future::Future;
8use std::sync::atomic::AtomicBool;
9use std::sync::Arc;
10use std::time::Duration;
11use tokio::runtime::{Builder, Handle, Runtime};
12use tokio::sync::Notify;
13use tokio::task::JoinHandle;
14
/// Log tag used by every `log_d!`/`log_e!` call in this module.
const TAG: &str = "StatsigRuntime";
16
17#[derive(Debug, Clone, PartialEq, Eq, Hash)]
18struct TaskId {
19    tag: String,
20    tokio_id: tokio::task::Id,
21}
22
/// Tracks tokio tasks spawned on behalf of the SDK so they can be awaited
/// by tag or id, counted, and aborted on shutdown.
pub struct StatsigRuntime {
    /// Join handles of spawned tasks that have not yet finished, been
    /// claimed by an `await_*` call, or been aborted by `shutdown`.
    spawned_tasks: Arc<Mutex<HashMap<TaskId, JoinHandle<()>>>>,
    /// Handed to every spawned task; `shutdown` calls `notify_waiters` on it.
    shutdown_notify: Arc<Notify>,
    /// Checked by each spawned task before it runs its body.
    is_shutdown: Arc<AtomicBool>,
}
28
29impl StatsigRuntime {
30    #[must_use]
31    pub fn get_runtime() -> Arc<StatsigRuntime> {
32        create_runtime_if_required();
33
34        Arc::new(StatsigRuntime {
35            spawned_tasks: Arc::new(Mutex::new(HashMap::new())),
36            shutdown_notify: Arc::new(Notify::new()),
37            is_shutdown: Arc::new(AtomicBool::new(false)),
38        })
39    }
40
41    pub fn get_handle(&self) -> Result<Handle, StatsigErr> {
42        if let Ok(handle) = Handle::try_current() {
43            return Ok(handle);
44        }
45
46        let global = StatsigGlobal::get();
47        let rt = global
48            .tokio_runtime
49            .try_lock_for(Duration::from_secs(5))
50            .ok_or_else(|| StatsigErr::LockFailure("Failed to lock tokio runtime".to_string()))?;
51
52        if let Some(rt) = rt.as_ref() {
53            return Ok(rt.handle().clone());
54        }
55
56        Err(StatsigErr::ThreadFailure(
57            "No tokio runtime found".to_string(),
58        ))
59    }
60
61    pub fn get_num_active_tasks(&self) -> usize {
62        match self.spawned_tasks.try_lock_for(Duration::from_secs(5)) {
63            Some(lock) => lock.len(),
64            None => {
65                log_e!(TAG, "Failed to lock spawned tasks for get_num_active_tasks");
66                0
67            }
68        }
69    }
70
71    pub fn shutdown(&self) {
72        self.shutdown_notify.notify_waiters();
73
74        match self.spawned_tasks.try_lock_for(Duration::from_secs(5)) {
75            Some(mut lock) => {
76                for (_, task) in lock.drain() {
77                    task.abort();
78                }
79            }
80            None => {
81                log_e!(TAG, "Failed to lock spawned tasks for shutdown");
82            }
83        }
84    }
85
86    pub fn spawn<F, Fut>(&self, tag: &str, task: F) -> Result<tokio::task::Id, StatsigErr>
87    where
88        F: FnOnce(Arc<Notify>) -> Fut + Send + 'static,
89        Fut: Future<Output = ()> + Send + 'static,
90    {
91        let tag_string = tag.to_string();
92        let shutdown_notify = self.shutdown_notify.clone();
93        let spawned_tasks = self.spawned_tasks.clone();
94        let is_shutdown = self.is_shutdown.clone();
95
96        log_d!(TAG, "Spawning task {}", tag);
97
98        let handle = self.get_handle()?.spawn(async move {
99            if is_shutdown.load(std::sync::atomic::Ordering::Relaxed) {
100                return;
101            }
102
103            let task_id = tokio::task::id();
104            log_d!(TAG, "Executing task {}.{}", tag_string, task_id);
105            task(shutdown_notify).await;
106            remove_join_handle_with_id(spawned_tasks, tag_string, &task_id);
107        });
108
109        Ok(self.insert_join_handle(tag, handle))
110    }
111
112    pub async fn await_tasks_with_tag(&self, tag: &str) {
113        let mut handles = Vec::new();
114
115        match self.spawned_tasks.try_lock_for(Duration::from_secs(5)) {
116            Some(mut lock) => {
117                let keys: Vec<TaskId> = lock.keys().cloned().collect();
118                for key in &keys {
119                    if key.tag == tag {
120                        let removed = if let Some(handle) = lock.remove(key) {
121                            handle
122                        } else {
123                            log_e!(TAG, "No running task found for tag {}", tag);
124                            continue;
125                        };
126
127                        handles.push(removed);
128                    }
129                }
130            }
131            None => {
132                log_e!(TAG, "Failed to lock spawned tasks for await_tasks_with_tag");
133                return;
134            }
135        };
136
137        join_all(handles).await;
138    }
139
140    pub async fn await_join_handle(
141        &self,
142        tag: &str,
143        handle_id: &tokio::task::Id,
144    ) -> Result<(), StatsigErr> {
145        let task_id = TaskId {
146            tag: tag.to_string(),
147            tokio_id: *handle_id,
148        };
149
150        let handle = match self.spawned_tasks.try_lock_for(Duration::from_secs(5)) {
151            Some(mut lock) => match lock.remove(&task_id) {
152                Some(handle) => handle,
153                None => {
154                    return Err(StatsigErr::ThreadFailure(
155                        "No running task found".to_string(),
156                    ));
157                }
158            },
159            None => {
160                log_e!(TAG, "Failed to lock spawned tasks for await_join_handle");
161                return Err(StatsigErr::ThreadFailure(
162                    "Failed to lock spawned tasks".to_string(),
163                ));
164            }
165        };
166
167        handle
168            .await
169            .map_err(|e| StatsigErr::ThreadFailure(e.to_string()))?;
170
171        Ok(())
172    }
173
174    fn insert_join_handle(&self, tag: &str, handle: JoinHandle<()>) -> tokio::task::Id {
175        let handle_id = handle.id();
176        let task_id = TaskId {
177            tag: tag.to_string(),
178            tokio_id: handle_id,
179        };
180
181        match self.spawned_tasks.try_lock_for(Duration::from_secs(5)) {
182            Some(mut lock) => {
183                lock.insert(task_id, handle);
184            }
185            None => {
186                log_e!(TAG, "Failed to lock spawned tasks for insert_join_handle");
187            }
188        }
189
190        handle_id
191    }
192}
193
194pub fn create_new_runtime() -> Runtime {
195    #[cfg(not(target_family = "wasm"))]
196    return Builder::new_multi_thread()
197        .worker_threads(5)
198        .thread_name("statsig")
199        .enable_all()
200        .build()
201        .expect("Failed to create a tokio Runtime");
202
203    #[cfg(target_family = "wasm")]
204    return Builder::new_current_thread()
205        .thread_name("statsig")
206        .enable_all()
207        .build()
208        .expect("Failed to create a tokio Runtime (single-threaded for wasm");
209}
210
211fn remove_join_handle_with_id(
212    spawned_tasks: Arc<Mutex<HashMap<TaskId, JoinHandle<()>>>>,
213    tag: String,
214    handle_id: &tokio::task::Id,
215) {
216    let task_id = TaskId {
217        tag,
218        tokio_id: *handle_id,
219    };
220
221    match spawned_tasks.try_lock_for(Duration::from_secs(5)) {
222        Some(mut lock) => {
223            lock.remove(&task_id);
224        }
225        None => {
226            log_e!(
227                TAG,
228                "Failed to lock spawned tasks for remove_join_handle_with_id"
229            );
230        }
231    }
232}
233
234fn create_runtime_if_required() {
235    if Handle::try_current().is_ok() {
236        log_d!(TAG, "External tokio runtime found");
237        return;
238    }
239
240    let global = StatsigGlobal::get();
241    let mut lock = global
242        .tokio_runtime
243        .try_lock_for(Duration::from_secs(5))
244        .expect("Failed to lock owned tokio runtime");
245
246    match lock.as_ref() {
247        Some(_) => {
248            log_d!(TAG, "Existing StatsigGlobal tokio runtime found");
249        }
250        None => {
251            log_d!(TAG, "Creating new tokio runtime for StatsigGlobal");
252            let rt = Arc::new(create_new_runtime());
253
254            lock.replace(rt);
255        }
256    };
257}
258
259impl Drop for StatsigRuntime {
260    fn drop(&mut self) {
261        self.shutdown();
262
263        // let opt_inner = match self.inner_runtime.lock() {
264        //     Ok(mut inner_runtime) => inner_runtime.take(),
265        //     Err(e) => {
266        //         log_e!(TAG, "Failed to lock inner runtime {}", e);
267        //         None
268        //     }
269        // };
270
271        // let inner = match opt_inner {
272        //     Some(inner) => inner,
273        //     None => {
274        //         log_d!(TAG, "Runtime owned by tokio");
275        //         return;
276        //     }
277        // };
278
279        // if Arc::strong_count(&inner) > 1 {
280        //     // Another instance is still using the Runtime, so we can't drop it
281        //     return;
282        // }
283
284        // if tokio::runtime::Handle::try_current().is_err() {
285        //     println!("Not inside the Tokio runtime. Will automatically drop(inner).");
286        //     // Not inside the Tokio runtime. Will automatically drop(inner).
287        //     return;
288        // }
289
290        // log_w!(TAG, "Attempt to shutdown runtime from inside runtime");
291        // std::thread::spawn(move || {
292        //     println!("Dropping inner runtime from outside the Tokio runtime");
293        //     // We should not drop from inside the runtime, but in the odd case we do,
294        //     // moving inner to a new thread will prevent a panic
295        //     drop(inner);
296        // });
297    }
298}