statsig_rust/
statsig_runtime.rs1use futures::future::join_all;
2use std::collections::HashMap;
3use std::future::Future;
4use std::sync::atomic::AtomicBool;
5use std::sync::{Arc, Mutex, Weak};
6use tokio::runtime::{Builder, Handle, Runtime};
7use tokio::sync::Notify;
8use tokio::task::JoinHandle;
9
10use crate::log_d;
11use crate::log_e;
12use crate::StatsigErr;
13
/// Log tag passed to every `log_d!`/`log_e!` call in this module.
const TAG: &str = "StatsigRuntime";
15
16lazy_static::lazy_static! {
17 static ref OWNED_TOKIO_RUNTIME: Mutex<Option<Weak<Runtime>>> = Mutex::new(None);
18}
19
20#[derive(Debug, Clone, PartialEq, Eq, Hash)]
21struct TaskId {
22 tag: String,
23 tokio_id: tokio::task::Id,
24}
25
26pub struct StatsigRuntime {
27 pub runtime_handle: Handle,
28 inner_runtime: Mutex<Option<Arc<Runtime>>>,
29 spawned_tasks: Arc<Mutex<HashMap<TaskId, JoinHandle<()>>>>,
30 shutdown_notify: Arc<Notify>,
31 is_shutdown: Arc<AtomicBool>,
32}
33
34impl StatsigRuntime {
35 #[must_use]
36 pub fn get_runtime() -> Arc<StatsigRuntime> {
37 let (opt_runtime, runtime_handle) = create_runtime_if_required();
38 let shutdown_notify = Notify::new();
39
40 Arc::new(StatsigRuntime {
41 inner_runtime: Mutex::new(opt_runtime),
42 runtime_handle,
43 spawned_tasks: Arc::new(Mutex::new(HashMap::new())),
44 shutdown_notify: Arc::new(shutdown_notify),
45 is_shutdown: Arc::new(AtomicBool::new(false)),
46 })
47 }
48
49 pub fn get_handle(&self) -> Handle {
50 self.runtime_handle.clone()
51 }
52
53 pub fn get_num_active_tasks(&self) -> usize {
54 match self.spawned_tasks.lock() {
55 Ok(lock) => lock.len(),
56 Err(e) => {
57 log_e!(TAG, "Failed to lock spawned tasks {}", e);
58 0
59 }
60 }
61 }
62
63 pub fn shutdown(&self) {
64 self.shutdown_notify.notify_waiters();
65
66 if let Ok(mut lock) = self.spawned_tasks.lock() {
67 for (_, task) in lock.drain() {
68 task.abort();
69 }
70 }
71 }
72
73 pub fn spawn<F, Fut>(&self, tag: &str, task: F) -> tokio::task::Id
74 where
75 F: FnOnce(Arc<Notify>) -> Fut + Send + 'static,
76 Fut: Future<Output = ()> + Send + 'static,
77 {
78 let tag_string = tag.to_string();
79 let shutdown_notify = self.shutdown_notify.clone();
80 let spawned_tasks = self.spawned_tasks.clone();
81 let is_shutdown = self.is_shutdown.clone();
82
83 log_d!(TAG, "Spawning task {}", tag);
84
85 let handle = self.runtime_handle.spawn(async move {
86 if is_shutdown.load(std::sync::atomic::Ordering::Relaxed) {
87 return;
88 }
89
90 let task_id = tokio::task::id();
91 log_d!(TAG, "Executing task {}.{}", tag_string, task_id);
92 task(shutdown_notify).await;
93 remove_join_handle_with_id(spawned_tasks, tag_string, &task_id);
94 });
95
96 self.insert_join_handle(tag, handle)
97 }
98
99 pub async fn await_tasks_with_tag(&self, tag: &str) {
100 let mut handles = Vec::new();
101
102 match self.spawned_tasks.lock() {
103 Ok(mut lock) => {
104 let keys: Vec<TaskId> = lock.keys().cloned().collect();
105 for key in &keys {
106 if key.tag == tag {
107 let removed = if let Some(handle) = lock.remove(key) {
108 handle
109 } else {
110 log_e!(TAG, "No running task found for tag {}", tag);
111 continue;
112 };
113
114 handles.push(removed);
115 }
116 }
117 }
118 Err(e) => {
119 log_e!(TAG, "Failed to lock spawned tasks {}", e);
120 return;
121 }
122 };
123
124 join_all(handles).await;
125 }
126
127 pub async fn await_join_handle(
128 &self,
129 tag: &str,
130 handle_id: &tokio::task::Id,
131 ) -> Result<(), StatsigErr> {
132 let task_id = TaskId {
133 tag: tag.to_string(),
134 tokio_id: *handle_id,
135 };
136
137 let handle = match self.spawned_tasks.lock() {
138 Ok(mut lock) => match lock.remove(&task_id) {
139 Some(handle) => handle,
140 None => {
141 return Err(StatsigErr::ThreadFailure(
142 "No running task found".to_string(),
143 ));
144 }
145 },
146 Err(e) => {
147 log_e!(
148 TAG,
149 "An error occurred while getting join handle with id: {}: {}",
150 handle_id,
151 e.to_string()
152 );
153 return Err(StatsigErr::ThreadFailure(e.to_string()));
154 }
155 };
156
157 handle
158 .await
159 .map_err(|e| StatsigErr::ThreadFailure(e.to_string()))?;
160
161 Ok(())
162 }
163
164 fn insert_join_handle(&self, tag: &str, handle: JoinHandle<()>) -> tokio::task::Id {
165 let handle_id = handle.id();
166 let task_id = TaskId {
167 tag: tag.to_string(),
168 tokio_id: handle_id,
169 };
170
171 match self.spawned_tasks.lock() {
172 Ok(mut lock) => {
173 lock.insert(task_id, handle);
174 }
175 Err(e) => {
176 log_e!(
177 TAG,
178 "An error occurred while inserting join handle: {}",
179 e.to_string()
180 );
181 }
182 }
183
184 handle_id
185 }
186}
187
188fn remove_join_handle_with_id(
189 spawned_tasks: Arc<Mutex<HashMap<TaskId, JoinHandle<()>>>>,
190 tag: String,
191 handle_id: &tokio::task::Id,
192) {
193 let task_id = TaskId {
194 tag,
195 tokio_id: *handle_id,
196 };
197
198 match spawned_tasks.lock() {
199 Ok(mut lock) => {
200 lock.remove(&task_id);
201 }
202 Err(e) => {
203 log_e!(
204 TAG,
205 "An error occurred while removing join handle {}",
206 e.to_string()
207 );
208 }
209 }
210}
211
212fn create_runtime_if_required() -> (Option<Arc<Runtime>>, Handle) {
213 if let Ok(handle) = Handle::try_current() {
214 log_d!(TAG, "Existing tokio runtime found");
215 return (None, handle);
216 }
217
218 let mut lock = OWNED_TOKIO_RUNTIME
219 .lock()
220 .expect("Failed to lock owned tokio runtime");
221
222 match lock.as_ref().and_then(|rt| rt.upgrade()) {
223 Some(rt) => (Some(rt.clone()), rt.handle().clone()),
224 None => {
225 let rt = Arc::new(
226 Builder::new_multi_thread()
227 .worker_threads(5)
228 .thread_name("statsig")
229 .enable_all()
230 .build()
231 .expect("Failed to find or create a tokio Runtime"),
232 );
233
234 let handle = rt.handle().clone();
235 lock.replace(Arc::downgrade(&rt));
236 (Some(rt), handle)
237 }
238 }
239}
240
241impl Drop for StatsigRuntime {
242 fn drop(&mut self) {
243 self.shutdown();
244
245 match self.inner_runtime.lock() {
246 Ok(mut inner_runtime) => {
247 let _ = inner_runtime.take();
248 }
249 Err(e) => {
250 log_e!(TAG, "Failed to lock inner runtime {}", e);
251 }
252 }
253 }
254}