// statsig_rust/statsig_runtime.rs

use crate::statsig_global::StatsigGlobal;
use crate::StatsigErr;
use crate::{log_d, log_e};
use futures::future::join_all;
use parking_lot::Mutex;
use std::collections::HashMap;
use std::future::Future;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use std::time::Duration;
use tokio::runtime::{Builder, Handle};
use tokio::sync::Notify;
use tokio::task::JoinHandle;

/// Log tag attached to every message emitted from this module.
const TAG: &str = "StatsigRuntime";

17#[derive(Debug, Clone, PartialEq, Eq, Hash)]
18struct TaskId {
19    tag: String,
20    tokio_id: tokio::task::Id,
21}
22
23pub struct StatsigRuntime {
24    spawned_tasks: Arc<Mutex<HashMap<TaskId, JoinHandle<()>>>>,
25    shutdown_notify: Arc<Notify>,
26    is_shutdown: Arc<AtomicBool>,
27}
28
29impl StatsigRuntime {
30    #[must_use]
31    pub fn get_runtime() -> Arc<StatsigRuntime> {
32        create_runtime_if_required();
33
34        Arc::new(StatsigRuntime {
35            spawned_tasks: Arc::new(Mutex::new(HashMap::new())),
36            shutdown_notify: Arc::new(Notify::new()),
37            is_shutdown: Arc::new(AtomicBool::new(false)),
38        })
39    }
40
41    pub fn get_handle(&self) -> Result<Handle, StatsigErr> {
42        if let Ok(handle) = Handle::try_current() {
43            return Ok(handle);
44        }
45
46        let global = StatsigGlobal::get();
47        let rt = global
48            .tokio_runtime
49            .try_lock_for(Duration::from_secs(5))
50            .ok_or_else(|| StatsigErr::LockFailure("Failed to lock tokio runtime".to_string()))?;
51
52        if let Some(rt) = rt.as_ref() {
53            return Ok(rt.handle().clone());
54        }
55
56        Err(StatsigErr::ThreadFailure(
57            "No tokio runtime found".to_string(),
58        ))
59    }
60
61    pub fn get_num_active_tasks(&self) -> usize {
62        match self.spawned_tasks.try_lock_for(Duration::from_secs(5)) {
63            Some(lock) => lock.len(),
64            None => {
65                log_e!(TAG, "Failed to lock spawned tasks for get_num_active_tasks");
66                0
67            }
68        }
69    }
70
71    pub fn shutdown(&self) {
72        self.shutdown_notify.notify_waiters();
73
74        match self.spawned_tasks.try_lock_for(Duration::from_secs(5)) {
75            Some(mut lock) => {
76                for (_, task) in lock.drain() {
77                    task.abort();
78                }
79            }
80            None => {
81                log_e!(TAG, "Failed to lock spawned tasks for shutdown");
82            }
83        }
84    }
85
86    pub fn spawn<F, Fut>(&self, tag: &str, task: F) -> Result<tokio::task::Id, StatsigErr>
87    where
88        F: FnOnce(Arc<Notify>) -> Fut + Send + 'static,
89        Fut: Future<Output = ()> + Send + 'static,
90    {
91        let tag_string = tag.to_string();
92        let shutdown_notify = self.shutdown_notify.clone();
93        let spawned_tasks = self.spawned_tasks.clone();
94        let is_shutdown = self.is_shutdown.clone();
95
96        log_d!(TAG, "Spawning task {}", tag);
97
98        let handle = self.get_handle()?.spawn(async move {
99            if is_shutdown.load(std::sync::atomic::Ordering::Relaxed) {
100                return;
101            }
102
103            let task_id = tokio::task::id();
104            log_d!(TAG, "Executing task {}.{}", tag_string, task_id);
105            task(shutdown_notify).await;
106            remove_join_handle_with_id(spawned_tasks, tag_string, &task_id);
107        });
108
109        Ok(self.insert_join_handle(tag, handle))
110    }
111
112    pub async fn await_tasks_with_tag(&self, tag: &str) {
113        let mut handles = Vec::new();
114
115        match self.spawned_tasks.try_lock_for(Duration::from_secs(5)) {
116            Some(mut lock) => {
117                let keys: Vec<TaskId> = lock.keys().cloned().collect();
118                for key in &keys {
119                    if key.tag == tag {
120                        let removed = if let Some(handle) = lock.remove(key) {
121                            handle
122                        } else {
123                            log_e!(TAG, "No running task found for tag {}", tag);
124                            continue;
125                        };
126
127                        handles.push(removed);
128                    }
129                }
130            }
131            None => {
132                log_e!(TAG, "Failed to lock spawned tasks for await_tasks_with_tag");
133                return;
134            }
135        };
136
137        join_all(handles).await;
138    }
139
140    pub async fn await_join_handle(
141        &self,
142        tag: &str,
143        handle_id: &tokio::task::Id,
144    ) -> Result<(), StatsigErr> {
145        let task_id = TaskId {
146            tag: tag.to_string(),
147            tokio_id: *handle_id,
148        };
149
150        let handle = match self.spawned_tasks.try_lock_for(Duration::from_secs(5)) {
151            Some(mut lock) => match lock.remove(&task_id) {
152                Some(handle) => handle,
153                None => {
154                    return Err(StatsigErr::ThreadFailure(
155                        "No running task found".to_string(),
156                    ));
157                }
158            },
159            None => {
160                log_e!(TAG, "Failed to lock spawned tasks for await_join_handle");
161                return Err(StatsigErr::ThreadFailure(
162                    "Failed to lock spawned tasks".to_string(),
163                ));
164            }
165        };
166
167        handle
168            .await
169            .map_err(|e| StatsigErr::ThreadFailure(e.to_string()))?;
170
171        Ok(())
172    }
173
174    fn insert_join_handle(&self, tag: &str, handle: JoinHandle<()>) -> tokio::task::Id {
175        let handle_id = handle.id();
176        let task_id = TaskId {
177            tag: tag.to_string(),
178            tokio_id: handle_id,
179        };
180
181        match self.spawned_tasks.try_lock_for(Duration::from_secs(5)) {
182            Some(mut lock) => {
183                lock.insert(task_id, handle);
184            }
185            None => {
186                log_e!(TAG, "Failed to lock spawned tasks for insert_join_handle");
187            }
188        }
189
190        handle_id
191    }
192}
193
194fn remove_join_handle_with_id(
195    spawned_tasks: Arc<Mutex<HashMap<TaskId, JoinHandle<()>>>>,
196    tag: String,
197    handle_id: &tokio::task::Id,
198) {
199    let task_id = TaskId {
200        tag,
201        tokio_id: *handle_id,
202    };
203
204    match spawned_tasks.try_lock_for(Duration::from_secs(5)) {
205        Some(mut lock) => {
206            lock.remove(&task_id);
207        }
208        None => {
209            log_e!(
210                TAG,
211                "Failed to lock spawned tasks for remove_join_handle_with_id"
212            );
213        }
214    }
215}
216
217fn create_runtime_if_required() {
218    if Handle::try_current().is_ok() {
219        log_d!(TAG, "External tokio runtime found");
220        return;
221    }
222
223    let global = StatsigGlobal::get();
224    let mut lock = global
225        .tokio_runtime
226        .try_lock_for(Duration::from_secs(5))
227        .expect("Failed to lock owned tokio runtime");
228
229    match lock.as_ref() {
230        Some(_) => {
231            log_d!(TAG, "Existing StatsigGlobal tokio runtime found");
232        }
233        None => {
234            log_d!(TAG, "Creating new tokio runtime for StatsigGlobal");
235            let rt = Arc::new(
236                Builder::new_multi_thread()
237                    .worker_threads(5)
238                    .thread_name("statsig")
239                    .enable_all()
240                    .build()
241                    .expect("Failed to find or create a tokio Runtime"),
242            );
243
244            lock.replace(rt);
245        }
246    };
247}
248
249impl Drop for StatsigRuntime {
250    fn drop(&mut self) {
251        self.shutdown();
252
253        // let opt_inner = match self.inner_runtime.lock() {
254        //     Ok(mut inner_runtime) => inner_runtime.take(),
255        //     Err(e) => {
256        //         log_e!(TAG, "Failed to lock inner runtime {}", e);
257        //         None
258        //     }
259        // };
260
261        // let inner = match opt_inner {
262        //     Some(inner) => inner,
263        //     None => {
264        //         log_d!(TAG, "Runtime owned by tokio");
265        //         return;
266        //     }
267        // };
268
269        // if Arc::strong_count(&inner) > 1 {
270        //     // Another instance is still using the Runtime, so we can't drop it
271        //     return;
272        // }
273
274        // if tokio::runtime::Handle::try_current().is_err() {
275        //     println!("Not inside the Tokio runtime. Will automatically drop(inner).");
276        //     // Not inside the Tokio runtime. Will automatically drop(inner).
277        //     return;
278        // }
279
280        // log_w!(TAG, "Attempt to shutdown runtime from inside runtime");
281        // std::thread::spawn(move || {
282        //     println!("Dropping inner runtime from outside the Tokio runtime");
283        //     // We should not drop from inside the runtime, but in the odd case we do,
284        //     // moving inner to a new thread will prevent a panic
285        //     drop(inner);
286        // });
287    }
288}