statsig_rust/
statsig_runtime.rs

1use crate::statsig_global::StatsigGlobal;
2use crate::StatsigErr;
3use crate::{log_d, log_e};
4use futures::future::join_all;
5use parking_lot::Mutex;
6use std::collections::HashMap;
7use std::future::Future;
8use std::sync::atomic::AtomicBool;
9use std::sync::Arc;
10use std::time::Duration;
11use tokio::runtime::{Builder, Handle, Runtime};
12use tokio::sync::Notify;
13use tokio::task::JoinHandle;
14
// Tag passed as the first argument to this file's `log_d!` / `log_e!` calls;
// `stringify!` turns the identifier into the string "StatsigRuntime".
const TAG: &str = stringify!(StatsigRuntime);
16
/// Key under which a spawned task's `JoinHandle` is stored in
/// `StatsigRuntime::spawned_tasks`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
struct TaskId {
    /// Caller-supplied label given to `StatsigRuntime::spawn`; several tasks may share it.
    tag: String,
    /// Id tokio assigned to the spawned task.
    tokio_id: tokio::task::Id,
}
22
/// Spawns tokio tasks and keeps their join handles so they can later be
/// awaited (by tag or by id) or aborted on shutdown/drop.
pub struct StatsigRuntime {
    /// Join handles of the tasks spawned through this instance, keyed by tag + tokio task id.
    spawned_tasks: Arc<Mutex<HashMap<TaskId, JoinHandle<()>>>>,
    /// Passed to every spawned task's closure; `shutdown` calls `notify_waiters` on it.
    shutdown_notify: Arc<Notify>,
    /// Read by spawned tasks before they start their work.
    /// NOTE(review): no code visible in this file ever stores `true` here — confirm
    /// whether `shutdown` is expected to set it.
    is_shutdown: Arc<AtomicBool>,
}
28
29impl StatsigRuntime {
30    #[must_use]
31    pub fn get_runtime() -> Arc<StatsigRuntime> {
32        create_runtime_if_required();
33
34        Arc::new(StatsigRuntime {
35            spawned_tasks: Arc::new(Mutex::new(HashMap::new())),
36            shutdown_notify: Arc::new(Notify::new()),
37            is_shutdown: Arc::new(AtomicBool::new(false)),
38        })
39    }
40
41    pub fn get_handle(&self) -> Result<Handle, StatsigErr> {
42        if let Ok(handle) = Handle::try_current() {
43            return Ok(handle);
44        }
45
46        let global = StatsigGlobal::get();
47        let mut rt = global
48            .tokio_runtime
49            .try_lock_for(Duration::from_secs(5))
50            .ok_or_else(|| StatsigErr::LockFailure("Failed to lock tokio runtime".to_string()))?;
51        if rt.is_none() {
52            *rt = Some(Arc::new(create_new_runtime()));
53        }
54        if let Some(rt) = rt.as_ref() {
55            return Ok(rt.handle().clone());
56        }
57
58        Err(StatsigErr::ThreadFailure(
59            "No tokio runtime found".to_string(),
60        ))
61    }
62
63    pub fn get_num_active_tasks(&self) -> usize {
64        match self.spawned_tasks.try_lock_for(Duration::from_secs(5)) {
65            Some(lock) => lock.len(),
66            None => {
67                log_e!(TAG, "Failed to lock spawned tasks for get_num_active_tasks");
68                0
69            }
70        }
71    }
72
73    pub fn shutdown(&self) {
74        self.shutdown_notify.notify_waiters();
75
76        match self.spawned_tasks.try_lock_for(Duration::from_secs(5)) {
77            Some(mut lock) => {
78                for (_, task) in lock.drain() {
79                    task.abort();
80                }
81            }
82            None => {
83                log_e!(TAG, "Failed to lock spawned tasks for shutdown");
84            }
85        }
86    }
87
88    pub fn spawn<F, Fut>(&self, tag: &str, task: F) -> Result<tokio::task::Id, StatsigErr>
89    where
90        F: FnOnce(Arc<Notify>) -> Fut + Send + 'static,
91        Fut: Future<Output = ()> + Send + 'static,
92    {
93        let tag_string = tag.to_string();
94        let shutdown_notify = self.shutdown_notify.clone();
95        let spawned_tasks = self.spawned_tasks.clone();
96        let is_shutdown = self.is_shutdown.clone();
97
98        log_d!(TAG, "Spawning task {}", tag);
99
100        let handle = self.get_handle()?.spawn(async move {
101            if is_shutdown.load(std::sync::atomic::Ordering::Relaxed) {
102                return;
103            }
104
105            let task_id = tokio::task::id();
106            log_d!(TAG, "Executing task {}.{}", tag_string, task_id);
107            task(shutdown_notify).await;
108            remove_join_handle_with_id(spawned_tasks, tag_string, &task_id);
109        });
110
111        Ok(self.insert_join_handle(tag, handle))
112    }
113
114    pub async fn await_tasks_with_tag(&self, tag: &str) {
115        let mut handles = Vec::new();
116
117        match self.spawned_tasks.try_lock_for(Duration::from_secs(5)) {
118            Some(mut lock) => {
119                let keys: Vec<TaskId> = lock.keys().cloned().collect();
120                for key in &keys {
121                    if key.tag == tag {
122                        let removed = if let Some(handle) = lock.remove(key) {
123                            handle
124                        } else {
125                            log_e!(TAG, "No running task found for tag {}", tag);
126                            continue;
127                        };
128
129                        handles.push(removed);
130                    }
131                }
132            }
133            None => {
134                log_e!(TAG, "Failed to lock spawned tasks for await_tasks_with_tag");
135                return;
136            }
137        };
138
139        join_all(handles).await;
140    }
141
142    pub async fn await_join_handle(
143        &self,
144        tag: &str,
145        handle_id: &tokio::task::Id,
146    ) -> Result<(), StatsigErr> {
147        let task_id = TaskId {
148            tag: tag.to_string(),
149            tokio_id: *handle_id,
150        };
151
152        let handle = match self.spawned_tasks.try_lock_for(Duration::from_secs(5)) {
153            Some(mut lock) => match lock.remove(&task_id) {
154                Some(handle) => handle,
155                None => {
156                    return Err(StatsigErr::ThreadFailure(
157                        "No running task found".to_string(),
158                    ));
159                }
160            },
161            None => {
162                log_e!(TAG, "Failed to lock spawned tasks for await_join_handle");
163                return Err(StatsigErr::ThreadFailure(
164                    "Failed to lock spawned tasks".to_string(),
165                ));
166            }
167        };
168
169        handle
170            .await
171            .map_err(|e| StatsigErr::ThreadFailure(e.to_string()))?;
172
173        Ok(())
174    }
175
176    pub fn get_running_task_ids(&self) -> Vec<(String, String)> {
177        let tasks = match self.spawned_tasks.try_lock_for(Duration::from_secs(5)) {
178            Some(lock) => lock,
179            None => {
180                log_e!(TAG, "Failed to lock spawned tasks for get_running_task_ids");
181                return Vec::new();
182            }
183        };
184
185        tasks
186            .keys()
187            .map(|key| (key.tag.clone(), key.tokio_id.to_string()))
188            .collect()
189    }
190
191    fn insert_join_handle(&self, tag: &str, handle: JoinHandle<()>) -> tokio::task::Id {
192        let handle_id = handle.id();
193        let task_id = TaskId {
194            tag: tag.to_string(),
195            tokio_id: handle_id,
196        };
197
198        match self.spawned_tasks.try_lock_for(Duration::from_secs(5)) {
199            Some(mut lock) => {
200                lock.insert(task_id, handle);
201            }
202            None => {
203                log_e!(TAG, "Failed to lock spawned tasks for insert_join_handle");
204            }
205        }
206
207        handle_id
208    }
209}
210
211pub fn create_new_runtime() -> Runtime {
212    #[cfg(not(target_family = "wasm"))]
213    return Builder::new_multi_thread()
214        .worker_threads(5)
215        .thread_name("statsig")
216        .enable_all()
217        .build()
218        .expect("Failed to create a tokio Runtime");
219
220    #[cfg(target_family = "wasm")]
221    return Builder::new_current_thread()
222        .thread_name("statsig")
223        .enable_all()
224        .build()
225        .expect("Failed to create a tokio Runtime (single-threaded for wasm");
226}
227
228fn remove_join_handle_with_id(
229    spawned_tasks: Arc<Mutex<HashMap<TaskId, JoinHandle<()>>>>,
230    tag: String,
231    handle_id: &tokio::task::Id,
232) {
233    let task_id = TaskId {
234        tag,
235        tokio_id: *handle_id,
236    };
237
238    match spawned_tasks.try_lock_for(Duration::from_secs(5)) {
239        Some(mut lock) => {
240            lock.remove(&task_id);
241        }
242        None => {
243            log_e!(
244                TAG,
245                "Failed to lock spawned tasks for remove_join_handle_with_id"
246            );
247        }
248    }
249}
250
251fn create_runtime_if_required() {
252    if Handle::try_current().is_ok() {
253        log_d!(TAG, "External tokio runtime found");
254        return;
255    }
256
257    let global = StatsigGlobal::get();
258    let mut lock = global
259        .tokio_runtime
260        .try_lock_for(Duration::from_secs(5))
261        .expect("Failed to lock owned tokio runtime");
262
263    match lock.as_ref() {
264        Some(_) => {
265            log_d!(TAG, "Existing StatsigGlobal tokio runtime found");
266        }
267        None => {
268            log_d!(TAG, "Creating new tokio runtime for StatsigGlobal");
269            let rt = Arc::new(create_new_runtime());
270
271            lock.replace(rt);
272        }
273    };
274}
275
276impl Drop for StatsigRuntime {
277    fn drop(&mut self) {
278        self.shutdown();
279
280        // let opt_inner = match self.inner_runtime.lock() {
281        //     Ok(mut inner_runtime) => inner_runtime.take(),
282        //     Err(e) => {
283        //         log_e!(TAG, "Failed to lock inner runtime {}", e);
284        //         None
285        //     }
286        // };
287
288        // let inner = match opt_inner {
289        //     Some(inner) => inner,
290        //     None => {
291        //         log_d!(TAG, "Runtime owned by tokio");
292        //         return;
293        //     }
294        // };
295
296        // if Arc::strong_count(&inner) > 1 {
297        //     // Another instance is still using the Runtime, so we can't drop it
298        //     return;
299        // }
300
301        // if tokio::runtime::Handle::try_current().is_err() {
302        //     println!("Not inside the Tokio runtime. Will automatically drop(inner).");
303        //     // Not inside the Tokio runtime. Will automatically drop(inner).
304        //     return;
305        // }
306
307        // log_w!(TAG, "Attempt to shutdown runtime from inside runtime");
308        // std::thread::spawn(move || {
309        //     println!("Dropping inner runtime from outside the Tokio runtime");
310        //     // We should not drop from inside the runtime, but in the odd case we do,
311        //     // moving inner to a new thread will prevent a panic
312        //     drop(inner);
313        // });
314    }
315}