statsig_rust/
statsig_runtime.rs1use crate::statsig_global::StatsigGlobal;
2use crate::StatsigErr;
3use crate::{log_d, log_e};
4use futures::future::join_all;
5use parking_lot::Mutex;
6use std::collections::HashMap;
7use std::future::Future;
8use std::sync::atomic::AtomicBool;
9use std::sync::Arc;
10use std::time::Duration;
11use tokio::runtime::{Builder, Handle, Runtime};
12use tokio::sync::Notify;
13use tokio::task::JoinHandle;
14
// Tag passed as the first argument to every `log_d!` / `log_e!` call in this file.
const TAG: &str = stringify!(StatsigRuntime);
16
/// Key under which a spawned task's `JoinHandle` is stored in
/// `StatsigRuntime::spawned_tasks`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
struct TaskId {
    /// Caller-supplied label given to `StatsigRuntime::spawn`.
    tag: String,
    /// Id tokio assigned to the spawned task.
    tokio_id: tokio::task::Id,
}
22
/// Tracks the tokio tasks spawned on behalf of the SDK so they can be counted,
/// awaited by tag or id, and aborted on shutdown.
pub struct StatsigRuntime {
    /// Join handles of tasks started through `spawn`, keyed by tag + tokio task id.
    spawned_tasks: Arc<Mutex<HashMap<TaskId, JoinHandle<()>>>>,
    /// Handed to every spawned task closure and signalled by `shutdown`.
    shutdown_notify: Arc<Notify>,
    /// Read by spawned tasks before they run their body.
    /// NOTE(review): nothing visible in this file ever stores `true` here — confirm
    /// whether `shutdown` is expected to set it.
    is_shutdown: Arc<AtomicBool>,
}
28
29impl StatsigRuntime {
30 #[must_use]
31 pub fn get_runtime() -> Arc<StatsigRuntime> {
32 create_runtime_if_required();
33
34 Arc::new(StatsigRuntime {
35 spawned_tasks: Arc::new(Mutex::new(HashMap::new())),
36 shutdown_notify: Arc::new(Notify::new()),
37 is_shutdown: Arc::new(AtomicBool::new(false)),
38 })
39 }
40
41 pub fn get_handle(&self) -> Result<Handle, StatsigErr> {
42 if let Ok(handle) = Handle::try_current() {
43 return Ok(handle);
44 }
45
46 let global = StatsigGlobal::get();
47 let rt = global
48 .tokio_runtime
49 .try_lock_for(Duration::from_secs(5))
50 .ok_or_else(|| StatsigErr::LockFailure("Failed to lock tokio runtime".to_string()))?;
51
52 if let Some(rt) = rt.as_ref() {
53 return Ok(rt.handle().clone());
54 }
55
56 Err(StatsigErr::ThreadFailure(
57 "No tokio runtime found".to_string(),
58 ))
59 }
60
61 pub fn get_num_active_tasks(&self) -> usize {
62 match self.spawned_tasks.try_lock_for(Duration::from_secs(5)) {
63 Some(lock) => lock.len(),
64 None => {
65 log_e!(TAG, "Failed to lock spawned tasks for get_num_active_tasks");
66 0
67 }
68 }
69 }
70
71 pub fn shutdown(&self) {
72 self.shutdown_notify.notify_waiters();
73
74 match self.spawned_tasks.try_lock_for(Duration::from_secs(5)) {
75 Some(mut lock) => {
76 for (_, task) in lock.drain() {
77 task.abort();
78 }
79 }
80 None => {
81 log_e!(TAG, "Failed to lock spawned tasks for shutdown");
82 }
83 }
84 }
85
86 pub fn spawn<F, Fut>(&self, tag: &str, task: F) -> Result<tokio::task::Id, StatsigErr>
87 where
88 F: FnOnce(Arc<Notify>) -> Fut + Send + 'static,
89 Fut: Future<Output = ()> + Send + 'static,
90 {
91 let tag_string = tag.to_string();
92 let shutdown_notify = self.shutdown_notify.clone();
93 let spawned_tasks = self.spawned_tasks.clone();
94 let is_shutdown = self.is_shutdown.clone();
95
96 log_d!(TAG, "Spawning task {}", tag);
97
98 let handle = self.get_handle()?.spawn(async move {
99 if is_shutdown.load(std::sync::atomic::Ordering::Relaxed) {
100 return;
101 }
102
103 let task_id = tokio::task::id();
104 log_d!(TAG, "Executing task {}.{}", tag_string, task_id);
105 task(shutdown_notify).await;
106 remove_join_handle_with_id(spawned_tasks, tag_string, &task_id);
107 });
108
109 Ok(self.insert_join_handle(tag, handle))
110 }
111
112 pub async fn await_tasks_with_tag(&self, tag: &str) {
113 let mut handles = Vec::new();
114
115 match self.spawned_tasks.try_lock_for(Duration::from_secs(5)) {
116 Some(mut lock) => {
117 let keys: Vec<TaskId> = lock.keys().cloned().collect();
118 for key in &keys {
119 if key.tag == tag {
120 let removed = if let Some(handle) = lock.remove(key) {
121 handle
122 } else {
123 log_e!(TAG, "No running task found for tag {}", tag);
124 continue;
125 };
126
127 handles.push(removed);
128 }
129 }
130 }
131 None => {
132 log_e!(TAG, "Failed to lock spawned tasks for await_tasks_with_tag");
133 return;
134 }
135 };
136
137 join_all(handles).await;
138 }
139
140 pub async fn await_join_handle(
141 &self,
142 tag: &str,
143 handle_id: &tokio::task::Id,
144 ) -> Result<(), StatsigErr> {
145 let task_id = TaskId {
146 tag: tag.to_string(),
147 tokio_id: *handle_id,
148 };
149
150 let handle = match self.spawned_tasks.try_lock_for(Duration::from_secs(5)) {
151 Some(mut lock) => match lock.remove(&task_id) {
152 Some(handle) => handle,
153 None => {
154 return Err(StatsigErr::ThreadFailure(
155 "No running task found".to_string(),
156 ));
157 }
158 },
159 None => {
160 log_e!(TAG, "Failed to lock spawned tasks for await_join_handle");
161 return Err(StatsigErr::ThreadFailure(
162 "Failed to lock spawned tasks".to_string(),
163 ));
164 }
165 };
166
167 handle
168 .await
169 .map_err(|e| StatsigErr::ThreadFailure(e.to_string()))?;
170
171 Ok(())
172 }
173
174 fn insert_join_handle(&self, tag: &str, handle: JoinHandle<()>) -> tokio::task::Id {
175 let handle_id = handle.id();
176 let task_id = TaskId {
177 tag: tag.to_string(),
178 tokio_id: handle_id,
179 };
180
181 match self.spawned_tasks.try_lock_for(Duration::from_secs(5)) {
182 Some(mut lock) => {
183 lock.insert(task_id, handle);
184 }
185 None => {
186 log_e!(TAG, "Failed to lock spawned tasks for insert_join_handle");
187 }
188 }
189
190 handle_id
191 }
192}
193
194pub fn create_new_runtime() -> Runtime {
195 #[cfg(not(target_family = "wasm"))]
196 return Builder::new_multi_thread()
197 .worker_threads(5)
198 .thread_name("statsig")
199 .enable_all()
200 .build()
201 .expect("Failed to create a tokio Runtime");
202
203 #[cfg(target_family = "wasm")]
204 return Builder::new_current_thread()
205 .thread_name("statsig")
206 .enable_all()
207 .build()
208 .expect("Failed to create a tokio Runtime (single-threaded for wasm");
209}
210
211fn remove_join_handle_with_id(
212 spawned_tasks: Arc<Mutex<HashMap<TaskId, JoinHandle<()>>>>,
213 tag: String,
214 handle_id: &tokio::task::Id,
215) {
216 let task_id = TaskId {
217 tag,
218 tokio_id: *handle_id,
219 };
220
221 match spawned_tasks.try_lock_for(Duration::from_secs(5)) {
222 Some(mut lock) => {
223 lock.remove(&task_id);
224 }
225 None => {
226 log_e!(
227 TAG,
228 "Failed to lock spawned tasks for remove_join_handle_with_id"
229 );
230 }
231 }
232}
233
234fn create_runtime_if_required() {
235 if Handle::try_current().is_ok() {
236 log_d!(TAG, "External tokio runtime found");
237 return;
238 }
239
240 let global = StatsigGlobal::get();
241 let mut lock = global
242 .tokio_runtime
243 .try_lock_for(Duration::from_secs(5))
244 .expect("Failed to lock owned tokio runtime");
245
246 match lock.as_ref() {
247 Some(_) => {
248 log_d!(TAG, "Existing StatsigGlobal tokio runtime found");
249 }
250 None => {
251 log_d!(TAG, "Creating new tokio runtime for StatsigGlobal");
252 let rt = Arc::new(create_new_runtime());
253
254 lock.replace(rt);
255 }
256 };
257}
258
// Dropping a `StatsigRuntime` performs the same teardown as an explicit
// `shutdown()` call.
impl Drop for StatsigRuntime {
    fn drop(&mut self) {
        // Signals the shutdown notify and aborts whatever is still in the task table.
        self.shutdown();

    }
}