statsig_rust/
statsig_runtime.rs1use crate::statsig_global::StatsigGlobal;
2use crate::StatsigErr;
3use crate::{log_d, log_e};
4use futures::future::join_all;
5use parking_lot::Mutex;
6use std::collections::HashMap;
7use std::future::Future;
8use std::sync::atomic::AtomicBool;
9use std::sync::Arc;
10use std::time::Duration;
11use tokio::runtime::{Builder, Handle, Runtime};
12use tokio::sync::Notify;
13use tokio::task::JoinHandle;
14
// Tag passed as the first argument to the `log_d!` / `log_e!` calls in this file.
const TAG: &str = "StatsigRuntime";
16
17#[derive(Debug, Clone, PartialEq, Eq, Hash)]
18struct TaskId {
19 tag: String,
20 tokio_id: tokio::task::Id,
21}
22
23pub struct StatsigRuntime {
24 spawned_tasks: Arc<Mutex<HashMap<TaskId, JoinHandle<()>>>>,
25 shutdown_notify: Arc<Notify>,
26 is_shutdown: Arc<AtomicBool>,
27}
28
29impl StatsigRuntime {
30 #[must_use]
31 pub fn get_runtime() -> Arc<StatsigRuntime> {
32 create_runtime_if_required();
33
34 Arc::new(StatsigRuntime {
35 spawned_tasks: Arc::new(Mutex::new(HashMap::new())),
36 shutdown_notify: Arc::new(Notify::new()),
37 is_shutdown: Arc::new(AtomicBool::new(false)),
38 })
39 }
40
41 pub fn get_handle(&self) -> Result<Handle, StatsigErr> {
42 if let Ok(handle) = Handle::try_current() {
43 return Ok(handle);
44 }
45
46 let global = StatsigGlobal::get();
47 let mut rt = global
48 .tokio_runtime
49 .try_lock_for(Duration::from_secs(5))
50 .ok_or_else(|| StatsigErr::LockFailure("Failed to lock tokio runtime".to_string()))?;
51 if rt.is_none() {
52 *rt = Some(Arc::new(create_new_runtime()));
53 }
54 if let Some(rt) = rt.as_ref() {
55 return Ok(rt.handle().clone());
56 }
57
58 Err(StatsigErr::ThreadFailure(
59 "No tokio runtime found".to_string(),
60 ))
61 }
62
63 pub fn get_num_active_tasks(&self) -> usize {
64 match self.spawned_tasks.try_lock_for(Duration::from_secs(5)) {
65 Some(lock) => lock.len(),
66 None => {
67 log_e!(TAG, "Failed to lock spawned tasks for get_num_active_tasks");
68 0
69 }
70 }
71 }
72
73 pub fn shutdown(&self) {
74 self.shutdown_notify.notify_waiters();
75
76 match self.spawned_tasks.try_lock_for(Duration::from_secs(5)) {
77 Some(mut lock) => {
78 for (_, task) in lock.drain() {
79 task.abort();
80 }
81 }
82 None => {
83 log_e!(TAG, "Failed to lock spawned tasks for shutdown");
84 }
85 }
86 }
87
88 pub fn spawn<F, Fut>(&self, tag: &str, task: F) -> Result<tokio::task::Id, StatsigErr>
89 where
90 F: FnOnce(Arc<Notify>) -> Fut + Send + 'static,
91 Fut: Future<Output = ()> + Send + 'static,
92 {
93 let tag_string = tag.to_string();
94 let shutdown_notify = self.shutdown_notify.clone();
95 let spawned_tasks = self.spawned_tasks.clone();
96 let is_shutdown = self.is_shutdown.clone();
97
98 log_d!(TAG, "Spawning task {}", tag);
99
100 let handle = self.get_handle()?.spawn(async move {
101 if is_shutdown.load(std::sync::atomic::Ordering::Relaxed) {
102 return;
103 }
104
105 let task_id = tokio::task::id();
106 log_d!(TAG, "Executing task {}.{}", tag_string, task_id);
107 task(shutdown_notify).await;
108 remove_join_handle_with_id(spawned_tasks, tag_string, &task_id);
109 });
110
111 Ok(self.insert_join_handle(tag, handle))
112 }
113
114 pub async fn await_tasks_with_tag(&self, tag: &str) {
115 let mut handles = Vec::new();
116
117 match self.spawned_tasks.try_lock_for(Duration::from_secs(5)) {
118 Some(mut lock) => {
119 let keys: Vec<TaskId> = lock.keys().cloned().collect();
120 for key in &keys {
121 if key.tag == tag {
122 let removed = if let Some(handle) = lock.remove(key) {
123 handle
124 } else {
125 log_e!(TAG, "No running task found for tag {}", tag);
126 continue;
127 };
128
129 handles.push(removed);
130 }
131 }
132 }
133 None => {
134 log_e!(TAG, "Failed to lock spawned tasks for await_tasks_with_tag");
135 return;
136 }
137 };
138
139 join_all(handles).await;
140 }
141
142 pub async fn await_join_handle(
143 &self,
144 tag: &str,
145 handle_id: &tokio::task::Id,
146 ) -> Result<(), StatsigErr> {
147 let task_id = TaskId {
148 tag: tag.to_string(),
149 tokio_id: *handle_id,
150 };
151
152 let handle = match self.spawned_tasks.try_lock_for(Duration::from_secs(5)) {
153 Some(mut lock) => match lock.remove(&task_id) {
154 Some(handle) => handle,
155 None => {
156 return Err(StatsigErr::ThreadFailure(
157 "No running task found".to_string(),
158 ));
159 }
160 },
161 None => {
162 log_e!(TAG, "Failed to lock spawned tasks for await_join_handle");
163 return Err(StatsigErr::ThreadFailure(
164 "Failed to lock spawned tasks".to_string(),
165 ));
166 }
167 };
168
169 handle
170 .await
171 .map_err(|e| StatsigErr::ThreadFailure(e.to_string()))?;
172
173 Ok(())
174 }
175
176 fn insert_join_handle(&self, tag: &str, handle: JoinHandle<()>) -> tokio::task::Id {
177 let handle_id = handle.id();
178 let task_id = TaskId {
179 tag: tag.to_string(),
180 tokio_id: handle_id,
181 };
182
183 match self.spawned_tasks.try_lock_for(Duration::from_secs(5)) {
184 Some(mut lock) => {
185 lock.insert(task_id, handle);
186 }
187 None => {
188 log_e!(TAG, "Failed to lock spawned tasks for insert_join_handle");
189 }
190 }
191
192 handle_id
193 }
194}
195
196pub fn create_new_runtime() -> Runtime {
197 #[cfg(not(target_family = "wasm"))]
198 return Builder::new_multi_thread()
199 .worker_threads(5)
200 .thread_name("statsig")
201 .enable_all()
202 .build()
203 .expect("Failed to create a tokio Runtime");
204
205 #[cfg(target_family = "wasm")]
206 return Builder::new_current_thread()
207 .thread_name("statsig")
208 .enable_all()
209 .build()
210 .expect("Failed to create a tokio Runtime (single-threaded for wasm");
211}
212
213fn remove_join_handle_with_id(
214 spawned_tasks: Arc<Mutex<HashMap<TaskId, JoinHandle<()>>>>,
215 tag: String,
216 handle_id: &tokio::task::Id,
217) {
218 let task_id = TaskId {
219 tag,
220 tokio_id: *handle_id,
221 };
222
223 match spawned_tasks.try_lock_for(Duration::from_secs(5)) {
224 Some(mut lock) => {
225 lock.remove(&task_id);
226 }
227 None => {
228 log_e!(
229 TAG,
230 "Failed to lock spawned tasks for remove_join_handle_with_id"
231 );
232 }
233 }
234}
235
236fn create_runtime_if_required() {
237 if Handle::try_current().is_ok() {
238 log_d!(TAG, "External tokio runtime found");
239 return;
240 }
241
242 let global = StatsigGlobal::get();
243 let mut lock = global
244 .tokio_runtime
245 .try_lock_for(Duration::from_secs(5))
246 .expect("Failed to lock owned tokio runtime");
247
248 match lock.as_ref() {
249 Some(_) => {
250 log_d!(TAG, "Existing StatsigGlobal tokio runtime found");
251 }
252 None => {
253 log_d!(TAG, "Creating new tokio runtime for StatsigGlobal");
254 let rt = Arc::new(create_new_runtime());
255
256 lock.replace(rt);
257 }
258 };
259}
260
261impl Drop for StatsigRuntime {
262 fn drop(&mut self) {
263 self.shutdown();
264
265 }
300}