statsig_rust/
statsig_runtime.rs

1use futures::future::join_all;
2use std::collections::HashMap;
3use std::future::Future;
4use std::sync::atomic::AtomicBool;
5use std::sync::{Arc, Mutex, Weak};
6use tokio::runtime::{Builder, Handle, Runtime};
7use tokio::sync::Notify;
8use tokio::task::JoinHandle;
9
10use crate::log_e;
11use crate::StatsigErr;
12use crate::{log_d, log_w};
13
/// Log tag passed to every `log_*` macro call in this module.
const TAG: &str = "StatsigRuntime";
15
16lazy_static::lazy_static! {
17    static ref OWNED_TOKIO_RUNTIME: Mutex<Option<Weak<Runtime>>> = Mutex::new(None);
18}
19
20#[derive(Debug, Clone, PartialEq, Eq, Hash)]
21struct TaskId {
22    tag: String,
23    tokio_id: tokio::task::Id,
24}
25
26pub struct StatsigRuntime {
27    pub runtime_handle: Handle,
28    inner_runtime: Mutex<Option<Arc<Runtime>>>,
29    spawned_tasks: Arc<Mutex<HashMap<TaskId, JoinHandle<()>>>>,
30    shutdown_notify: Arc<Notify>,
31    is_shutdown: Arc<AtomicBool>,
32}
33
34impl StatsigRuntime {
35    #[must_use]
36    pub fn get_runtime() -> Arc<StatsigRuntime> {
37        let (opt_runtime, runtime_handle) = create_runtime_if_required();
38        let shutdown_notify = Notify::new();
39
40        Arc::new(StatsigRuntime {
41            inner_runtime: Mutex::new(opt_runtime),
42            runtime_handle,
43            spawned_tasks: Arc::new(Mutex::new(HashMap::new())),
44            shutdown_notify: Arc::new(shutdown_notify),
45            is_shutdown: Arc::new(AtomicBool::new(false)),
46        })
47    }
48
49    pub fn get_handle(&self) -> Handle {
50        self.runtime_handle.clone()
51    }
52
53    pub fn get_num_active_tasks(&self) -> usize {
54        match self.spawned_tasks.lock() {
55            Ok(lock) => lock.len(),
56            Err(e) => {
57                log_e!(TAG, "Failed to lock spawned tasks {}", e);
58                0
59            }
60        }
61    }
62
63    pub fn shutdown(&self) {
64        self.shutdown_notify.notify_waiters();
65
66        if let Ok(mut lock) = self.spawned_tasks.lock() {
67            for (_, task) in lock.drain() {
68                task.abort();
69            }
70        }
71    }
72
73    pub fn spawn<F, Fut>(&self, tag: &str, task: F) -> tokio::task::Id
74    where
75        F: FnOnce(Arc<Notify>) -> Fut + Send + 'static,
76        Fut: Future<Output = ()> + Send + 'static,
77    {
78        let tag_string = tag.to_string();
79        let shutdown_notify = self.shutdown_notify.clone();
80        let spawned_tasks = self.spawned_tasks.clone();
81        let is_shutdown = self.is_shutdown.clone();
82
83        log_d!(TAG, "Spawning task {}", tag);
84
85        let handle = self.runtime_handle.spawn(async move {
86            if is_shutdown.load(std::sync::atomic::Ordering::Relaxed) {
87                return;
88            }
89
90            let task_id = tokio::task::id();
91            log_d!(TAG, "Executing task {}.{}", tag_string, task_id);
92            task(shutdown_notify).await;
93            remove_join_handle_with_id(spawned_tasks, tag_string, &task_id);
94        });
95
96        self.insert_join_handle(tag, handle)
97    }
98
99    pub async fn await_tasks_with_tag(&self, tag: &str) {
100        let mut handles = Vec::new();
101
102        match self.spawned_tasks.lock() {
103            Ok(mut lock) => {
104                let keys: Vec<TaskId> = lock.keys().cloned().collect();
105                for key in &keys {
106                    if key.tag == tag {
107                        let removed = if let Some(handle) = lock.remove(key) {
108                            handle
109                        } else {
110                            log_e!(TAG, "No running task found for tag {}", tag);
111                            continue;
112                        };
113
114                        handles.push(removed);
115                    }
116                }
117            }
118            Err(e) => {
119                log_e!(TAG, "Failed to lock spawned tasks {}", e);
120                return;
121            }
122        };
123
124        join_all(handles).await;
125    }
126
127    pub async fn await_join_handle(
128        &self,
129        tag: &str,
130        handle_id: &tokio::task::Id,
131    ) -> Result<(), StatsigErr> {
132        let task_id = TaskId {
133            tag: tag.to_string(),
134            tokio_id: *handle_id,
135        };
136
137        let handle = match self.spawned_tasks.lock() {
138            Ok(mut lock) => match lock.remove(&task_id) {
139                Some(handle) => handle,
140                None => {
141                    return Err(StatsigErr::ThreadFailure(
142                        "No running task found".to_string(),
143                    ));
144                }
145            },
146            Err(e) => {
147                log_e!(
148                    TAG,
149                    "An error occurred while getting join handle with id: {}: {}",
150                    handle_id,
151                    e.to_string()
152                );
153                return Err(StatsigErr::ThreadFailure(e.to_string()));
154            }
155        };
156
157        handle
158            .await
159            .map_err(|e| StatsigErr::ThreadFailure(e.to_string()))?;
160
161        Ok(())
162    }
163
164    fn insert_join_handle(&self, tag: &str, handle: JoinHandle<()>) -> tokio::task::Id {
165        let handle_id = handle.id();
166        let task_id = TaskId {
167            tag: tag.to_string(),
168            tokio_id: handle_id,
169        };
170
171        match self.spawned_tasks.lock() {
172            Ok(mut lock) => {
173                lock.insert(task_id, handle);
174            }
175            Err(e) => {
176                log_e!(
177                    TAG,
178                    "An error occurred while inserting join handle: {}",
179                    e.to_string()
180                );
181            }
182        }
183
184        handle_id
185    }
186}
187
188fn remove_join_handle_with_id(
189    spawned_tasks: Arc<Mutex<HashMap<TaskId, JoinHandle<()>>>>,
190    tag: String,
191    handle_id: &tokio::task::Id,
192) {
193    let task_id = TaskId {
194        tag,
195        tokio_id: *handle_id,
196    };
197
198    match spawned_tasks.lock() {
199        Ok(mut lock) => {
200            lock.remove(&task_id);
201        }
202        Err(e) => {
203            log_e!(
204                TAG,
205                "An error occurred while removing join handle {}",
206                e.to_string()
207            );
208        }
209    }
210}
211
212fn create_runtime_if_required() -> (Option<Arc<Runtime>>, Handle) {
213    if let Ok(handle) = Handle::try_current() {
214        log_d!(TAG, "Existing tokio runtime found");
215        return (None, handle);
216    }
217
218    let mut lock = OWNED_TOKIO_RUNTIME
219        .lock()
220        .expect("Failed to lock owned tokio runtime");
221
222    match lock.as_ref().and_then(|rt| rt.upgrade()) {
223        Some(rt) => (Some(rt.clone()), rt.handle().clone()),
224        None => {
225            let rt = Arc::new(
226                Builder::new_multi_thread()
227                    .worker_threads(5)
228                    .thread_name("statsig")
229                    .enable_all()
230                    .build()
231                    .expect("Failed to find or create a tokio Runtime"),
232            );
233
234            let handle = rt.handle().clone();
235            lock.replace(Arc::downgrade(&rt));
236            (Some(rt), handle)
237        }
238    }
239}
240
241impl Drop for StatsigRuntime {
242    fn drop(&mut self) {
243        self.shutdown();
244
245        let opt_inner = match self.inner_runtime.lock() {
246            Ok(mut inner_runtime) => inner_runtime.take(),
247            Err(e) => {
248                log_e!(TAG, "Failed to lock inner runtime {}", e);
249                None
250            }
251        };
252
253        let inner = match opt_inner {
254            Some(inner) => inner,
255            None => {
256                log_d!(TAG, "Runtime owned by tokio");
257                return;
258            }
259        };
260
261        if Arc::strong_count(&inner) > 1 {
262            // Another instance is still using the Runtime, so we can't drop it
263            return;
264        }
265
266        if tokio::runtime::Handle::try_current().is_err() {
267            // Not inside the Tokio runtime. Will automatically drop(inner).
268            return;
269        }
270
271        log_w!(TAG, "Attempt to shutdown runtime from inside runtime");
272        std::thread::spawn(move || {
273            // We should not drop from inside the runtime, but in the odd case we do,
274            // moving inner to a new thread will prevent a panic
275            drop(inner);
276        });
277    }
278}