statsig_rust/
statsig_runtime.rs1use crate::statsig_global::StatsigGlobal;
2use crate::StatsigErr;
3use crate::{log_d, log_e};
4use futures::future::join_all;
5use parking_lot::Mutex;
6use std::collections::HashMap;
7use std::future::Future;
8use std::sync::atomic::AtomicBool;
9use std::sync::Arc;
10use std::time::Duration;
11use tokio::runtime::{Builder, Handle, Runtime};
12use tokio::sync::Notify;
13use tokio::task::JoinHandle;
14
/// Log tag attached to every message emitted from this module.
const TAG: &str = "StatsigRuntime";
16
17#[derive(Debug, Clone, PartialEq, Eq, Hash)]
18struct TaskId {
19 tag: String,
20 tokio_id: tokio::task::Id,
21}
22
23pub struct StatsigRuntime {
24 spawned_tasks: Arc<Mutex<HashMap<TaskId, JoinHandle<()>>>>,
25 shutdown_notify: Arc<Notify>,
26 is_shutdown: Arc<AtomicBool>,
27}
28
29impl StatsigRuntime {
30 #[must_use]
31 pub fn get_runtime() -> Arc<StatsigRuntime> {
32 create_runtime_if_required();
33
34 Arc::new(StatsigRuntime {
35 spawned_tasks: Arc::new(Mutex::new(HashMap::new())),
36 shutdown_notify: Arc::new(Notify::new()),
37 is_shutdown: Arc::new(AtomicBool::new(false)),
38 })
39 }
40
41 pub fn get_handle(&self) -> Result<Handle, StatsigErr> {
42 if let Ok(handle) = Handle::try_current() {
43 return Ok(handle);
44 }
45
46 let global = StatsigGlobal::get();
47 let mut rt = global
48 .tokio_runtime
49 .try_lock_for(Duration::from_secs(5))
50 .ok_or_else(|| StatsigErr::LockFailure("Failed to lock tokio runtime".to_string()))?;
51 if rt.is_none() {
52 *rt = Some(Arc::new(create_new_runtime()));
53 }
54 if let Some(rt) = rt.as_ref() {
55 return Ok(rt.handle().clone());
56 }
57
58 Err(StatsigErr::ThreadFailure(
59 "No tokio runtime found".to_string(),
60 ))
61 }
62
63 pub fn get_num_active_tasks(&self) -> usize {
64 match self.spawned_tasks.try_lock_for(Duration::from_secs(5)) {
65 Some(lock) => lock.len(),
66 None => {
67 log_e!(TAG, "Failed to lock spawned tasks for get_num_active_tasks");
68 0
69 }
70 }
71 }
72
73 pub fn shutdown(&self) {
74 self.shutdown_notify.notify_waiters();
75
76 match self.spawned_tasks.try_lock_for(Duration::from_secs(5)) {
77 Some(mut lock) => {
78 for (_, task) in lock.drain() {
79 task.abort();
80 }
81 }
82 None => {
83 log_e!(TAG, "Failed to lock spawned tasks for shutdown");
84 }
85 }
86 }
87
88 pub fn spawn<F, Fut>(&self, tag: &str, task: F) -> Result<tokio::task::Id, StatsigErr>
89 where
90 F: FnOnce(Arc<Notify>) -> Fut + Send + 'static,
91 Fut: Future<Output = ()> + Send + 'static,
92 {
93 let tag_string = tag.to_string();
94 let shutdown_notify = self.shutdown_notify.clone();
95 let spawned_tasks = self.spawned_tasks.clone();
96 let is_shutdown = self.is_shutdown.clone();
97
98 log_d!(TAG, "Spawning task {}", tag);
99
100 let handle = self.get_handle()?.spawn(async move {
101 if is_shutdown.load(std::sync::atomic::Ordering::Relaxed) {
102 return;
103 }
104
105 let task_id = tokio::task::id();
106 log_d!(TAG, "Executing task {}.{}", tag_string, task_id);
107 task(shutdown_notify).await;
108 remove_join_handle_with_id(spawned_tasks, tag_string, &task_id);
109 });
110
111 Ok(self.insert_join_handle(tag, handle))
112 }
113
114 pub async fn await_tasks_with_tag(&self, tag: &str) {
115 let mut handles = Vec::new();
116
117 match self.spawned_tasks.try_lock_for(Duration::from_secs(5)) {
118 Some(mut lock) => {
119 let keys: Vec<TaskId> = lock.keys().cloned().collect();
120 for key in &keys {
121 if key.tag == tag {
122 let removed = if let Some(handle) = lock.remove(key) {
123 handle
124 } else {
125 log_e!(TAG, "No running task found for tag {}", tag);
126 continue;
127 };
128
129 handles.push(removed);
130 }
131 }
132 }
133 None => {
134 log_e!(TAG, "Failed to lock spawned tasks for await_tasks_with_tag");
135 return;
136 }
137 };
138
139 join_all(handles).await;
140 }
141
142 pub async fn await_join_handle(
143 &self,
144 tag: &str,
145 handle_id: &tokio::task::Id,
146 ) -> Result<(), StatsigErr> {
147 let task_id = TaskId {
148 tag: tag.to_string(),
149 tokio_id: *handle_id,
150 };
151
152 let handle = match self.spawned_tasks.try_lock_for(Duration::from_secs(5)) {
153 Some(mut lock) => match lock.remove(&task_id) {
154 Some(handle) => handle,
155 None => {
156 return Err(StatsigErr::ThreadFailure(
157 "No running task found".to_string(),
158 ));
159 }
160 },
161 None => {
162 log_e!(TAG, "Failed to lock spawned tasks for await_join_handle");
163 return Err(StatsigErr::ThreadFailure(
164 "Failed to lock spawned tasks".to_string(),
165 ));
166 }
167 };
168
169 handle
170 .await
171 .map_err(|e| StatsigErr::ThreadFailure(e.to_string()))?;
172
173 Ok(())
174 }
175
176 pub fn get_running_task_ids(&self) -> Vec<(String, String)> {
177 let tasks = match self.spawned_tasks.try_lock_for(Duration::from_secs(5)) {
178 Some(lock) => lock,
179 None => {
180 log_e!(TAG, "Failed to lock spawned tasks for get_running_task_ids");
181 return Vec::new();
182 }
183 };
184
185 tasks
186 .keys()
187 .map(|key| (key.tag.clone(), key.tokio_id.to_string()))
188 .collect()
189 }
190
191 fn insert_join_handle(&self, tag: &str, handle: JoinHandle<()>) -> tokio::task::Id {
192 let handle_id = handle.id();
193 let task_id = TaskId {
194 tag: tag.to_string(),
195 tokio_id: handle_id,
196 };
197
198 match self.spawned_tasks.try_lock_for(Duration::from_secs(5)) {
199 Some(mut lock) => {
200 lock.insert(task_id, handle);
201 }
202 None => {
203 log_e!(TAG, "Failed to lock spawned tasks for insert_join_handle");
204 }
205 }
206
207 handle_id
208 }
209}
210
211pub fn create_new_runtime() -> Runtime {
212 #[cfg(not(target_family = "wasm"))]
213 return Builder::new_multi_thread()
214 .worker_threads(5)
215 .thread_name("statsig")
216 .enable_all()
217 .build()
218 .expect("Failed to create a tokio Runtime");
219
220 #[cfg(target_family = "wasm")]
221 return Builder::new_current_thread()
222 .thread_name("statsig")
223 .enable_all()
224 .build()
225 .expect("Failed to create a tokio Runtime (single-threaded for wasm");
226}
227
228fn remove_join_handle_with_id(
229 spawned_tasks: Arc<Mutex<HashMap<TaskId, JoinHandle<()>>>>,
230 tag: String,
231 handle_id: &tokio::task::Id,
232) {
233 let task_id = TaskId {
234 tag,
235 tokio_id: *handle_id,
236 };
237
238 match spawned_tasks.try_lock_for(Duration::from_secs(5)) {
239 Some(mut lock) => {
240 lock.remove(&task_id);
241 }
242 None => {
243 log_e!(
244 TAG,
245 "Failed to lock spawned tasks for remove_join_handle_with_id"
246 );
247 }
248 }
249}
250
251fn create_runtime_if_required() {
252 if Handle::try_current().is_ok() {
253 log_d!(TAG, "External tokio runtime found");
254 return;
255 }
256
257 let global = StatsigGlobal::get();
258 let mut lock = global
259 .tokio_runtime
260 .try_lock_for(Duration::from_secs(5))
261 .expect("Failed to lock owned tokio runtime");
262
263 match lock.as_ref() {
264 Some(_) => {
265 log_d!(TAG, "Existing StatsigGlobal tokio runtime found");
266 }
267 None => {
268 log_d!(TAG, "Creating new tokio runtime for StatsigGlobal");
269 let rt = Arc::new(create_new_runtime());
270
271 lock.replace(rt);
272 }
273 };
274}
275
276impl Drop for StatsigRuntime {
277 fn drop(&mut self) {
278 self.shutdown();
279
280 }
315}